Expand HMAC support in Salt; require an identifier be passed in to specify type but allow generation with and without. Add a StaticSalt ID for testing functions. Fix bugs; unit tests pass.
This commit is contained in:
parent
b655f6b858
commit
5dde76fa1c
|
@ -32,4 +32,4 @@ type BackendConfig struct {
|
|||
}
|
||||
|
||||
// Factory is the factory function to create an audit backend.
|
||||
type Factory func(BackendConfig) (Backend, error)
|
||||
type Factory func(*BackendConfig) (Backend, error)
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
//
|
||||
// The structure is modified in-place.
|
||||
func Hash(salter *salt.Salt, raw interface{}) error {
|
||||
fn := salter.GetHMAC
|
||||
fn := salter.GetIdentifiedHMAC
|
||||
|
||||
switch s := raw.(type) {
|
||||
case *logical.Auth:
|
||||
|
@ -86,17 +86,6 @@ func HashStructure(s interface{}, cb HashCallback) (interface{}, error) {
|
|||
// a value.
|
||||
type HashCallback func(string) string
|
||||
|
||||
// HashSHA1 returns a HashCallback that hashes data with SHA1 and
|
||||
// with an optional salt. If salt is a blank string, no salt is used.
|
||||
/*
|
||||
func HashSHA1(salt string) HashCallback {
|
||||
return func(v string) (string, error) {
|
||||
hashed := sha1.Sum([]byte(v + salt))
|
||||
return "sha1:" + hex.EncodeToString(hashed[:]), nil
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// hashWalker implements interfaces for the reflectwalk package
|
||||
// (github.com/mitchellh/reflectwalk) that can be used to automatically
|
||||
// replace primitives with a hashed value.
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/mitchellh/copystructure"
|
||||
)
|
||||
|
@ -88,7 +90,7 @@ func TestHash(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
&logical.Auth{ClientToken: "foo"},
|
||||
&logical.Auth{ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"},
|
||||
&logical.Auth{ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a"},
|
||||
},
|
||||
{
|
||||
&logical.Request{
|
||||
|
@ -98,7 +100,7 @@ func TestHash(t *testing.T) {
|
|||
},
|
||||
&logical.Request{
|
||||
Data: map[string]interface{}{
|
||||
"foo": "sha1:62cdb7020ff920e5aa642c3d4066950dd1f01f4d",
|
||||
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -110,7 +112,7 @@ func TestHash(t *testing.T) {
|
|||
},
|
||||
&logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"foo": "sha1:62cdb7020ff920e5aa642c3d4066950dd1f01f4d",
|
||||
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -133,14 +135,22 @@ func TestHash(t *testing.T) {
|
|||
IssueTime: now,
|
||||
},
|
||||
|
||||
ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33",
|
||||
ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
localSalt, err := salt.NewSalt(nil, &salt.Config{
|
||||
HMAC: sha256.New,
|
||||
HMACType: "hmac-sha256",
|
||||
StaticSalt: "foo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error instantiating salt: %s", err)
|
||||
}
|
||||
for _, tc := range cases {
|
||||
input := fmt.Sprintf("%#v", tc.Input)
|
||||
if err := Hash(tc.Input); err != nil {
|
||||
if err := Hash(localSalt, tc.Input); err != nil {
|
||||
t.Fatalf("err: %s\n\n%s", err, input)
|
||||
}
|
||||
if !reflect.DeepEqual(tc.Input, tc.Output) {
|
||||
|
@ -176,8 +186,8 @@ func TestHashWalker(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
output, err := HashStructure(tc.Input, func(string) (string, error) {
|
||||
return replaceText, nil
|
||||
output, err := HashStructure(tc.Input, func(string) string {
|
||||
return replaceText
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s\n\n%#v", err, tc.Input)
|
||||
|
@ -187,14 +197,3 @@ func TestHashWalker(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashSHA1(t *testing.T) {
|
||||
fn := HashSHA1("")
|
||||
result, err := fn("foo")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if result != "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33" {
|
||||
t.Fatalf("bad: %#v", result)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/mitchellh/copystructure"
|
||||
)
|
||||
|
||||
func Factory(conf audit.BackendConfig) (audit.Backend, error) {
|
||||
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
if conf.Salt == nil {
|
||||
return nil, fmt.Errorf("Nil salt passed in")
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"github.com/mitchellh/copystructure"
|
||||
)
|
||||
|
||||
func Factory(conf audit.BackendConfig) (audit.Backend, error) {
|
||||
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
if conf.Salt == nil {
|
||||
return nil, fmt.Errorf("Nil salt passed in")
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ type Salt struct {
|
|||
salt string
|
||||
generated bool
|
||||
hmac hash.Hash
|
||||
hmacType string
|
||||
}
|
||||
|
||||
type HashFunc func([]byte) []byte
|
||||
|
@ -44,6 +45,14 @@ type Config struct {
|
|||
// HMAC allows specification of a hash function to use for
|
||||
// the HMAC helpers
|
||||
HMAC func() hash.Hash
|
||||
|
||||
// String prepended to HMAC strings for identification.
|
||||
// Required if using HMAC
|
||||
HMACType string
|
||||
|
||||
// A static string to use if set. If not set, one will be
|
||||
// generated and persisted. This value will *not* be persisted.
|
||||
StaticSalt string
|
||||
}
|
||||
|
||||
// NewSalt creates a new salt based on the configuration
|
||||
|
@ -64,35 +73,49 @@ func NewSalt(view logical.Storage, config *Config) (*Salt, error) {
|
|||
config: config,
|
||||
}
|
||||
|
||||
// Look for the salt
|
||||
raw, err := view.Get(config.Location)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read salt: %v", err)
|
||||
}
|
||||
var raw *logical.StorageEntry
|
||||
var err error
|
||||
if config.StaticSalt != "" {
|
||||
s.salt = config.StaticSalt
|
||||
} else {
|
||||
if view != nil {
|
||||
// Look for the salt
|
||||
raw, err = view.Get(config.Location)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read salt: %v", err)
|
||||
}
|
||||
|
||||
// Restore the salt if it exists
|
||||
if raw != nil {
|
||||
s.salt = string(raw.Value)
|
||||
// Restore the salt if it exists
|
||||
if raw != nil {
|
||||
s.salt = string(raw.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new salt if necessary
|
||||
if s.salt == "" {
|
||||
s.salt = uuid.GenerateUUID()
|
||||
s.generated = true
|
||||
raw = &logical.StorageEntry{
|
||||
Key: config.Location,
|
||||
Value: []byte(s.salt),
|
||||
}
|
||||
if err := view.Put(raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist salt: %v", err)
|
||||
if view != nil {
|
||||
raw = &logical.StorageEntry{
|
||||
Key: config.Location,
|
||||
Value: []byte(s.salt),
|
||||
}
|
||||
if err := view.Put(raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist salt: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.HMAC != nil {
|
||||
if len(config.HMACType) == 0 {
|
||||
return nil, fmt.Errorf("HMACType must be defined")
|
||||
}
|
||||
s.hmac = hmac.New(config.HMAC, []byte(s.salt))
|
||||
if s.hmac == nil {
|
||||
return nil, fmt.Errorf("failed to instantiate HMAC function")
|
||||
}
|
||||
s.hmacType = config.HMACType
|
||||
}
|
||||
|
||||
return s, nil
|
||||
|
@ -104,7 +127,7 @@ func (s *Salt) SaltID(id string) string {
|
|||
return SaltID(s.salt, id, s.config.HashFunc)
|
||||
}
|
||||
|
||||
// SaltIDandHMAC is used to apply a salt and hash function to an ID to make sure
|
||||
// GetHMAC is used to apply a salt and hash function to an ID to make sure
|
||||
// it is not reversible, with an additional HMAC
|
||||
func (s *Salt) GetHMAC(id string) string {
|
||||
if s.hmac == nil {
|
||||
|
@ -112,7 +135,19 @@ func (s *Salt) GetHMAC(id string) string {
|
|||
}
|
||||
s.hmac.Reset()
|
||||
s.hmac.Write([]byte(id))
|
||||
return string(s.hmac.Sum(nil))
|
||||
return hex.EncodeToString(s.hmac.Sum(nil))
|
||||
}
|
||||
|
||||
// GetIdentifiedHMAC is used to apply a salt and hash function to an ID to make sure
|
||||
// it is not reversible, with an additional HMAC, and ID prepended
|
||||
func (s *Salt) GetIdentifiedHMAC(id string) string {
|
||||
if s.hmac == nil {
|
||||
return ""
|
||||
}
|
||||
s.hmac.Reset()
|
||||
s.hmac.Write([]byte(id))
|
||||
|
||||
return s.hmacType + ":" + hex.EncodeToString(s.hmac.Sum(nil))
|
||||
}
|
||||
|
||||
// DidGenerate returns if the underlying salt value was generated
|
||||
|
|
|
@ -209,11 +209,12 @@ func (c *Core) newAuditBackend(t string, view logical.Storage, conf map[string]s
|
|||
salter, err := salt.NewSalt(view, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
HMAC: sha256.New,
|
||||
HMACType: "hmac-sha256",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[ERR] core: unable to generate salt: %v", err)
|
||||
}
|
||||
return f(audit.BackendConfig{
|
||||
return f(&audit.BackendConfig{
|
||||
Salt: salter,
|
||||
Config: conf,
|
||||
})
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
)
|
||||
|
||||
type NoopAudit struct {
|
||||
Config *audit.BackendConfig
|
||||
ReqErr error
|
||||
ReqAuth []*logical.Auth
|
||||
Req []*logical.Request
|
||||
|
@ -44,8 +45,10 @@ func (n *NoopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical
|
|||
|
||||
func TestCore_EnableAudit(t *testing.T) {
|
||||
c, key, _ := TestCoreUnsealed(t)
|
||||
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
return &NoopAudit{}, nil
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &NoopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
me := &MountEntry{
|
||||
|
@ -66,8 +69,10 @@ func TestCore_EnableAudit(t *testing.T) {
|
|||
AuditBackends: make(map[string]audit.Factory),
|
||||
DisableMlock: true,
|
||||
}
|
||||
conf.AuditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
return &NoopAudit{}, nil
|
||||
conf.AuditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &NoopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
c2, err := NewCore(conf)
|
||||
if err != nil {
|
||||
|
@ -94,8 +99,10 @@ func TestCore_EnableAudit(t *testing.T) {
|
|||
|
||||
func TestCore_DisableAudit(t *testing.T) {
|
||||
c, key, _ := TestCoreUnsealed(t)
|
||||
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
return &NoopAudit{}, nil
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &NoopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
err := c.disableAudit("foo")
|
||||
|
|
|
@ -841,7 +841,10 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) {
|
|||
// Create a noop audit backend
|
||||
noop := &NoopAudit{}
|
||||
c, _, root := TestCoreUnsealed(t)
|
||||
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
noop = &NoopAudit{
|
||||
Config: config,
|
||||
}
|
||||
return noop, nil
|
||||
}
|
||||
|
||||
|
@ -920,7 +923,10 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) {
|
|||
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
return noopBack, nil
|
||||
}
|
||||
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
noop = &NoopAudit{
|
||||
Config: config,
|
||||
}
|
||||
return noop, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -521,8 +521,10 @@ func TestSystemBackend_policyCRUD(t *testing.T) {
|
|||
|
||||
func TestSystemBackend_enableAudit(t *testing.T) {
|
||||
c, b, _ := testCoreSystemBackend(t)
|
||||
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
return &NoopAudit{}, nil
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &NoopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
req := logical.TestRequest(t, logical.WriteOperation, "audit/foo")
|
||||
|
@ -552,8 +554,10 @@ func TestSystemBackend_enableAudit_invalid(t *testing.T) {
|
|||
|
||||
func TestSystemBackend_auditTable(t *testing.T) {
|
||||
c, b, _ := testCoreSystemBackend(t)
|
||||
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
return &NoopAudit{}, nil
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &NoopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
req := logical.TestRequest(t, logical.WriteOperation, "audit/foo")
|
||||
|
@ -586,8 +590,10 @@ func TestSystemBackend_auditTable(t *testing.T) {
|
|||
|
||||
func TestSystemBackend_disableAudit(t *testing.T) {
|
||||
c, b, _ := testCoreSystemBackend(t)
|
||||
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
|
||||
return &NoopAudit{}, nil
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &NoopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
req := logical.TestRequest(t, logical.WriteOperation, "audit/foo")
|
||||
|
|
|
@ -57,8 +57,10 @@ oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
|
|||
// TestCore returns a pure in-memory, uninitialized core for testing.
|
||||
func TestCore(t *testing.T) *Core {
|
||||
noopAudits := map[string]audit.Factory{
|
||||
"noop": func(audit.BackendConfig) (audit.Backend, error) {
|
||||
return new(noopAudit), nil
|
||||
"noop": func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &noopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
noopBackends := make(map[string]logical.Factory)
|
||||
|
@ -240,7 +242,9 @@ func AddTestLogicalBackend(name string, factory logical.Factory) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type noopAudit struct{}
|
||||
type noopAudit struct {
|
||||
Config *audit.BackendConfig
|
||||
}
|
||||
|
||||
func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {
|
||||
return nil
|
||||
|
|
Loading…
Reference in New Issue