open-vault/vault/request_handling_test.go

300 lines
6.8 KiB
Go

package vault
import (
"strings"
"testing"
"time"
"github.com/armon/go-metrics"
uuid "github.com/hashicorp/go-uuid"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
)
// TestRequestHandling_Wrapping registers PassthroughBackendFactory under the
// "kv" type, mounts it at "wraptest", writes a value, and then checks that a
// read carrying a 15 second wrap TTL yields a response whose wrap info uses
// that same TTL.
func TestRequestHandling_Wrapping(t *testing.T) {
	// wrapTTL is both the TTL requested on the read and the TTL the response's
	// wrap info is expected to report.
	const wrapTTL = 15 * time.Second

	core, _, root := TestCoreUnsealed(t)

	core.logicalBackends["kv"] = PassthroughBackendFactory

	// Stop right away if no UUID could be generated instead of mounting with
	// an empty identifier.
	meUUID, err := uuid.GenerateUUID()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	err = core.mount(namespace.RootContext(nil), &MountEntry{
		Table: mountTableType,
		UUID:  meUUID,
		Path:  "wraptest",
		Type:  "kv",
	})
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// No duration specified: this write carries no wrap info and is expected
	// to return a nil response.
	req := &logical.Request{
		Path:        "wraptest/foo",
		ClientToken: root,
		Operation:   logical.UpdateOperation,
		Data: map[string]interface{}{
			"zip": "zap",
		},
	}
	resp, err := core.HandleRequest(namespace.RootContext(nil), req)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if resp != nil {
		t.Fatalf("bad: %#v", resp)
	}

	// Read the same path back, this time requesting a wrapped response.
	req = &logical.Request{
		Path:        "wraptest/foo",
		ClientToken: root,
		Operation:   logical.ReadOperation,
		WrapInfo: &logical.RequestWrapInfo{
			TTL: wrapTTL,
		},
	}
	resp, err = core.HandleRequest(namespace.RootContext(nil), req)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if resp == nil {
		t.Fatalf("bad: %v", resp)
	}
	if resp.WrapInfo == nil || resp.WrapInfo.TTL != wrapTTL {
		t.Fatalf("bad: %#v", resp)
	}
}
// TestRequestHandling_LoginWrapping enables the userpass auth method, creates
// a user, and then logs in twice: first without wrap info on the request, in
// which case the response must not carry wrap info, and then with a 15 second
// wrap TTL, in which case the response must carry wrap info with that TTL.
func TestRequestHandling_LoginWrapping(t *testing.T) {
	const wrapTTL = 15 * time.Second

	core, _, root := TestCoreUnsealed(t)
	if err := core.loadMounts(namespace.RootContext(nil)); err != nil {
		t.Fatalf("err: %v", err)
	}
	core.credentialBackends["userpass"] = credUserpass.Factory

	// submit hands a request to the core with a context obtained from
	// namespace.RootContext(nil) for each call.
	submit := func(r *logical.Request) (*logical.Response, error) {
		return core.HandleRequest(namespace.RootContext(nil), r)
	}

	// Enable the userpass auth method with the root token; a nil response is
	// expected.
	adminReq := &logical.Request{
		Path:        "sys/auth/userpass",
		ClientToken: root,
		Operation:   logical.UpdateOperation,
		Data: map[string]interface{}{
			"type": "userpass",
		},
		Connection: &logical.Connection{},
	}
	out, err := submit(adminReq)
	switch {
	case err != nil:
		t.Fatalf("err: %v", err)
	case out != nil:
		t.Fatalf("bad: %#v", out)
	}

	// Reuse the same request object to create the "test" user.
	adminReq.Path = "auth/userpass/users/test"
	adminReq.Data = map[string]interface{}{
		"password": "foo",
		"policies": "default",
	}
	out, err = submit(adminReq)
	switch {
	case err != nil:
		t.Fatalf("err: %v", err)
	case out != nil:
		t.Fatalf("bad: %#v", out)
	}

	// Log in without wrap info on the request: the response must exist and
	// must not carry wrap info.
	plainLogin := &logical.Request{
		Path:      "auth/userpass/login/test",
		Operation: logical.UpdateOperation,
		Data: map[string]interface{}{
			"password": "foo",
		},
		Connection: &logical.Connection{},
	}
	out, err = submit(plainLogin)
	switch {
	case err != nil:
		t.Fatalf("err: %v", err)
	case out == nil:
		t.Fatalf("bad: %v", out)
	case out.WrapInfo != nil:
		t.Fatalf("bad: %#v", out)
	}

	// Log in again with a wrap TTL on the request: the response must carry
	// wrap info with the same TTL.
	wrappedLogin := &logical.Request{
		Path:      "auth/userpass/login/test",
		Operation: logical.UpdateOperation,
		WrapInfo: &logical.RequestWrapInfo{
			TTL: wrapTTL,
		},
		Data: map[string]interface{}{
			"password": "foo",
		},
		Connection: &logical.Connection{},
	}
	out, err = submit(wrappedLogin)
	switch {
	case err != nil:
		t.Fatalf("err: %v", err)
	case out == nil:
		t.Fatalf("bad: %v", out)
	case out.WrapInfo == nil || out.WrapInfo.TTL != wrapTTL:
		t.Fatalf("bad: %#v", out)
	}
}
// labelsMatch reports whether every key in expected is present in actual with
// an identical value. Keys that exist only in actual are ignored, so an empty
// or nil expected map always matches.
func labelsMatch(actual, expected map[string]string) bool {
	for name, want := range expected {
		// A missing key is a mismatch even when the expected value is "".
		got, present := actual[name]
		if !present || got != want {
			return false
		}
	}
	return true
}
// checkCounter looks in the first interval reported by inmemSink for a counter
// whose name starts with keyPrefix and whose labels contain every entry of
// expectedLabels, then verifies that the counter holds exactly one sample with
// a sum of 1.
//
// The test is skipped when the sink reports more than one interval. It fails
// fatally when the sink reports no interval or when no counter has the prefix.
// If counters with the prefix exist but none carries the expected labels, a
// label error is reported and the count/sum checks run against the last
// counter with the prefix that was visited.
func checkCounter(t *testing.T, inmemSink *metrics.InmemSink, keyPrefix string, expectedLabels map[string]string) {
	t.Helper()

	intervals := inmemSink.Data()
	// Guard the intervals[0] access below.
	if len(intervals) == 0 {
		t.Fatalf("No metric intervals available while looking for %q", keyPrefix)
	}
	if len(intervals) > 1 {
		t.Skip("Detected interval crossing.")
	}

	var counter *metrics.SampledValue
	var labels map[string]string
	for _, c := range intervals[0].Counters {
		if !strings.HasPrefix(c.Name, keyPrefix) {
			continue
		}
		// Copy the element before taking its address. With per-loop range
		// variables (before Go 1.22) &c would alias a variable that later
		// iterations overwrite, including entries skipped by the prefix check.
		candidate := c
		counter = &candidate
		labels = make(map[string]string, len(candidate.Labels))
		for _, l := range candidate.Labels {
			labels[l.Name] = l.Value
		}
		// Distinguish between different label sets
		if labelsMatch(labels, expectedLabels) {
			break
		}
	}

	if counter == nil {
		t.Fatalf("No %q counter found with matching labels", keyPrefix)
	}
	if !labelsMatch(labels, expectedLabels) {
		t.Errorf("No matching label set, found %v", labels)
	}
	if counter.Count != 1 {
		t.Errorf("Counter number of samples %v is not 1.", counter.Count)
	}
	if counter.Sum != 1.0 {
		t.Errorf("Counter sum %v is not 1.", counter.Sum)
	}
}
// TestRequestHandling_LoginMetric installs an in-memory metrics sink on the
// core, enables userpass, creates a user, performs one login that requests a
// wrapped response, and then checks for two "token.creation" counters: one
// labelled with the userpass auth method and one labelled response_wrapping.
func TestRequestHandling_LoginMetric(t *testing.T) {
	core, _, root := TestCoreUnsealed(t)
	if err := core.loadMounts(namespace.RootContext(nil)); err != nil {
		t.Fatalf("err: %v", err)
	}
	core.credentialBackends["userpass"] = credUserpass.Factory

	// Both sink durations are made extremely large; checkCounter skips the
	// test if more than one interval shows up.
	sink := metrics.NewInmemSink(1000000*time.Hour, 2000000*time.Hour)
	core.metricSink = &metricsutil.ClusterMetricSink{
		ClusterName: "test-cluster",
		Sink:        sink,
	}

	// dispatch hands a request to the core with a context obtained from
	// namespace.RootContext(nil) for each call.
	dispatch := func(r *logical.Request) (*logical.Response, error) {
		return core.HandleRequest(namespace.RootContext(nil), r)
	}

	// Setup mount
	setupReq := &logical.Request{
		Path:        "sys/auth/userpass",
		ClientToken: root,
		Operation:   logical.UpdateOperation,
		Data: map[string]interface{}{
			"type": "userpass",
		},
		Connection: &logical.Connection{},
	}
	reply, err := dispatch(setupReq)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if reply != nil {
		t.Fatalf("bad: %#v", reply)
	}

	// Create user, reusing the setup request object.
	setupReq.Path = "auth/userpass/users/test"
	setupReq.Data = map[string]interface{}{
		"password": "foo",
		"policies": "default",
	}
	reply, err = dispatch(setupReq)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if reply != nil {
		t.Fatalf("bad: %#v", reply)
	}

	// Login with response wrapping
	loginReq := &logical.Request{
		Path:      "auth/userpass/login/test",
		Operation: logical.UpdateOperation,
		Data: map[string]interface{}{
			"password": "foo",
		},
		WrapInfo: &logical.RequestWrapInfo{
			TTL: 15 * time.Second,
		},
		Connection: &logical.Connection{},
	}
	reply, err = dispatch(loginReq)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if reply == nil {
		t.Fatalf("bad: %v", reply)
	}

	// expectCreation checks for a "token.creation" counter; only the auth
	// method and creation TTL labels differ between the two expected counters.
	expectCreation := func(authMethod, creationTTL string) {
		t.Helper()
		checkCounter(t, sink, "token.creation", map[string]string{
			"cluster":      "test-cluster",
			"namespace":    "root",
			"auth_method":  authMethod,
			"mount_point":  "auth/userpass/",
			"creation_ttl": creationTTL,
			"token_type":   "service",
		})
	}

	// There should be two counters
	expectCreation("userpass", "+Inf")
	expectCreation("response_wrapping", "1m")
}