package vault import ( "strings" "testing" "time" "github.com/armon/go-metrics" uuid "github.com/hashicorp/go-uuid" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" ) func TestRequestHandling_Wrapping(t *testing.T) { core, _, root := TestCoreUnsealed(t) core.logicalBackends["kv"] = PassthroughBackendFactory meUUID, _ := uuid.GenerateUUID() err := core.mount(namespace.RootContext(nil), &MountEntry{ Table: mountTableType, UUID: meUUID, Path: "wraptest", Type: "kv", }) if err != nil { t.Fatalf("err: %v", err) } // No duration specified req := &logical.Request{ Path: "wraptest/foo", ClientToken: root, Operation: logical.UpdateOperation, Data: map[string]interface{}{ "zip": "zap", }, } resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp != nil { t.Fatalf("bad: %#v", resp) } req = &logical.Request{ Path: "wraptest/foo", ClientToken: root, Operation: logical.ReadOperation, WrapInfo: &logical.RequestWrapInfo{ TTL: time.Duration(15 * time.Second), }, } resp, err = core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp == nil { t.Fatalf("bad: %v", resp) } if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) { t.Fatalf("bad: %#v", resp) } } func TestRequestHandling_LoginWrapping(t *testing.T) { core, _, root := TestCoreUnsealed(t) if err := core.loadMounts(namespace.RootContext(nil)); err != nil { t.Fatalf("err: %v", err) } core.credentialBackends["userpass"] = credUserpass.Factory // No duration specified req := &logical.Request{ Path: "sys/auth/userpass", ClientToken: root, Operation: logical.UpdateOperation, Data: map[string]interface{}{ "type": "userpass", }, Connection: &logical.Connection{}, } resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp != nil { t.Fatalf("bad: %#v", resp) } req.Path = "auth/userpass/users/test" req.Data = map[string]interface{}{ "password": "foo", "policies": "default", } resp, err = core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp != nil { t.Fatalf("bad: %#v", resp) } req = &logical.Request{ Path: "auth/userpass/login/test", Operation: logical.UpdateOperation, Data: map[string]interface{}{ "password": "foo", }, Connection: &logical.Connection{}, } resp, err = core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp == nil { t.Fatalf("bad: %v", resp) } if resp.WrapInfo != nil { t.Fatalf("bad: %#v", resp) } req = &logical.Request{ Path: "auth/userpass/login/test", Operation: logical.UpdateOperation, WrapInfo: &logical.RequestWrapInfo{ TTL: time.Duration(15 * time.Second), }, Data: map[string]interface{}{ "password": "foo", }, Connection: &logical.Connection{}, } resp, err = core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp == nil { t.Fatalf("bad: %v", resp) } if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) { t.Fatalf("bad: %#v", resp) } } func labelsMatch(actual, expected map[string]string) bool { for expected_label, expected_val := range expected { if v, ok := actual[expected_label]; ok { if v != expected_val { return false } } else { return false } } return true } func checkCounter(t *testing.T, inmemSink *metrics.InmemSink, keyPrefix string, expectedLabels map[string]string) { t.Helper() intervals := inmemSink.Data() if len(intervals) > 1 { t.Skip("Detected interval crossing.") } var counter *metrics.SampledValue = nil var labels map[string]string for _, c := range intervals[0].Counters { if !strings.HasPrefix(c.Name, keyPrefix) { continue } counter = &c labels = make(map[string]string) for _, l := range counter.Labels { labels[l.Name] = l.Value } // Distinguish between different label sets if labelsMatch(labels, expectedLabels) { break } } if counter == nil { t.Fatalf("No %q counter found with matching labels", keyPrefix) } if !labelsMatch(labels, expectedLabels) { t.Errorf("No matching label set, found %v", labels) } if counter.Count != 1 { t.Errorf("Counter number of samples %v is not 1.", counter.Count) } if counter.Sum != 1.0 { t.Errorf("Counter sum %v is not 1.", counter.Sum) } } func TestRequestHandling_LoginMetric(t *testing.T) { core := TestCore(t) // Replace sink before unseal! inmemSink := metrics.NewInmemSink( 1000000*time.Hour, 2000000*time.Hour) core.metricSink = metricsutil.NewClusterMetricSink("test-cluster", inmemSink) _, _, root := testCoreUnsealed(t, core) if err := core.loadMounts(namespace.RootContext(nil)); err != nil { t.Fatalf("err: %v", err) } core.credentialBackends["userpass"] = credUserpass.Factory // Setup mount req := &logical.Request{ Path: "sys/auth/userpass", ClientToken: root, Operation: logical.UpdateOperation, Data: map[string]interface{}{ "type": "userpass", }, Connection: &logical.Connection{}, } resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp != nil { t.Fatalf("bad: %#v", resp) } // Create user req.Path = "auth/userpass/users/test" req.Data = map[string]interface{}{ "password": "foo", "policies": "default", } resp, err = core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp != nil { t.Fatalf("bad: %#v", resp) } // Login with response wrapping req = &logical.Request{ Path: "auth/userpass/login/test", Operation: logical.UpdateOperation, Data: map[string]interface{}{ "password": "foo", }, WrapInfo: &logical.RequestWrapInfo{ TTL: time.Duration(15 * time.Second), }, Connection: &logical.Connection{}, } resp, err = core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp == nil { t.Fatalf("bad: %v", resp) } // There should be two counters checkCounter(t, inmemSink, "token.creation", map[string]string{ "cluster": "test-cluster", "namespace": "root", "auth_method": "userpass", "mount_point": "auth/userpass/", "creation_ttl": "+Inf", "token_type": "service", }, ) checkCounter(t, inmemSink, "token.creation", map[string]string{ "cluster": "test-cluster", "namespace": "root", "auth_method": "response_wrapping", "mount_point": "auth/userpass/", "creation_ttl": "1m", "token_type": "service", }, ) } func TestRequestHandling_SecretLeaseMetric(t *testing.T) { core, _, root := TestCoreUnsealed(t) inmemSink := metrics.NewInmemSink( 1000000*time.Hour, 2000000*time.Hour) core.metricSink = metricsutil.NewClusterMetricSink("test-cluster", inmemSink) // Create a key with a lease req := logical.TestRequest(t, logical.UpdateOperation, "secret/foo") req.Data["foo"] = "bar" req.ClientToken = root resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp != nil { t.Fatalf("bad: %#v", resp) } // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root req.SetTokenEntry(&logical.TokenEntry{ID: root, NamespaceID: "root", Policies: []string{"root"}}) resp, err = core.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } if resp == nil || resp.Secret == nil || resp.Secret.LeaseID == "" { t.Fatalf("bad: %#v", resp) } checkCounter(t, inmemSink, "secret.lease.creation", map[string]string{ "cluster": "test-cluster", "namespace": "root", "secret_engine": "kv", "mount_point": "secret/", "creation_ttl": "+Inf", }, ) }