// open-vault/vault/external_tests/quotas/quotas_test.go
// 2020-06-26 17:13:16 -04:00
//
// 385 lines
// 9.5 KiB
// Go

package quotas
import (
"fmt"
"testing"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/builtin/logical/pki"
"github.com/hashicorp/vault/helper/testhelpers/teststorage"
"github.com/hashicorp/vault/vault"
"go.uber.org/atomic"
)
// testLookupOnlyPolicy is an ACL policy document that grants only the
// "create" and "update" capabilities on /auth/token/lookup.
// NOTE(review): it is not referenced in the code shown here; confirm it is
// still needed by the rest of the file.
const testLookupOnlyPolicy = `
path "/auth/token/lookup" {
capabilities = [ "create", "update"]
}
`
// coreConfig is the core configuration shared by every test cluster in this
// file. It registers the userpass auth method factory under "userpass" and
// the PKI secrets engine factory under "pki", so the tests can enable them.
var coreConfig = &vault.CoreConfig{
	CredentialBackends: map[string]logical.Factory{
		"userpass": userpass.Factory,
	},
	LogicalBackends: map[string]logical.Factory{
		"pki": pki.Factory,
	},
}
// setupMounts does two things, failing the calling test immediately on any
// error:
//   - enables the userpass auth method with a user "foo" (password "bar");
//   - mounts the PKI secrets engine at "pki", generates an internal root
//     certificate for testvault.com, and creates a role named "test".
func setupMounts(t *testing.T, client *api.Client) {
	t.Helper()

	sysAPI := client.Sys()
	logicalAPI := client.Logical()

	// Auth side: userpass with a single known user.
	if err := sysAPI.EnableAuthWithOptions("userpass", &api.EnableAuthOptions{Type: "userpass"}); err != nil {
		t.Fatal(err)
	}
	userData := map[string]interface{}{"password": "bar"}
	if _, err := logicalAPI.Write("auth/userpass/users/foo", userData); err != nil {
		t.Fatal(err)
	}

	// Secrets side: PKI mount with an internally generated root and a role.
	if err := sysAPI.Mount("pki", &api.MountInput{Type: "pki"}); err != nil {
		t.Fatal(err)
	}
	rootData := map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	}
	if _, err := logicalAPI.Write("pki/root/generate/internal", rootData); err != nil {
		t.Fatal(err)
	}
	roleData := map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	}
	if _, err := logicalAPI.Write("pki/roles/test", roleData); err != nil {
		t.Fatal(err)
	}
}
// teardownMounts first unmounts the "pki" secrets engine, then disables the
// "userpass" auth method. Any error fails the calling test immediately.
func teardownMounts(t *testing.T, client *api.Client) {
	t.Helper()

	sysAPI := client.Sys()

	err := sysAPI.Unmount("pki")
	if err != nil {
		t.Fatal(err)
	}

	err = sysAPI.DisableAuth("userpass")
	if err != nil {
		t.Fatal(err)
	}
}
// testRPS calls reqFunc back-to-back until the duration d has passed.
// reqFunc receives two shared counters and is expected to increment one of
// them per call. testRPS returns three values:
//   - the final success count;
//   - the final failure count;
//   - the total time spent.
//
// The loop only stops once d has elapsed, so the returned duration is at
// least d. It also includes the last, possibly overrunning, call.
func testRPS(reqFunc func(numSuccess, numFail *atomic.Int32), d time.Duration) (int32, int32, time.Duration) {
	var (
		succeeded = atomic.NewInt32(0)
		failed    = atomic.NewInt32(0)
	)

	began := time.Now()
	deadline := began.Add(d)
	for time.Until(deadline) > 0 {
		reqFunc(succeeded, failed)
	}

	total := time.Since(began)
	return succeeded.Load(), failed.Load(), total
}
// waitForRemovalOrTimeout reads path through c once every tick. It stops in
// one of three ways:
//   - a read returns a nil response, which is taken to mean the resource has
//     been removed; it then returns nil;
//   - a read fails; it returns the wrapped read error;
//   - the duration to elapses first; it returns a timeout error.
func waitForRemovalOrTimeout(c *api.Client, path string, tick, to time.Duration) error {
	// Use a stoppable ticker/timer instead of time.Tick/time.After. The
	// ticker behind time.Tick is never stopped, so before Go 1.23 every call
	// to this helper leaked a running ticker.
	ticker := time.NewTicker(tick)
	defer ticker.Stop()
	timeout := time.NewTimer(to)
	defer timeout.Stop()

	// wait for the resource to be removed
	for {
		select {
		case <-timeout.C:
			return fmt.Errorf("timeout exceeded waiting for resource to be deleted: %s", path)
		case <-ticker.C:
			resp, err := c.Logical().Read(path)
			if err != nil {
				return fmt.Errorf("reading %q: %w", path, err)
			}
			if resp == nil {
				return nil
			}
		}
	}
}
// TestQuotas_RateLimitQuota_Mount checks a rate limit quota scoped to the
// "pki/" mount in two phases:
//   - with a low rate, hammering pki/cert/ca_chain produces some failures and
//     never more successes than burst + rate*elapsed allows;
//   - after the quota is raised to a high rate and burst, no request fails.
func TestQuotas_RateLimitQuota_Mount(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client
	vault.TestWaitActive(t, core)

	err := client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}

	// Each read of the mount counts as a success or a failure. The
	// assertions below rely on rate-limited requests surfacing as errors.
	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("pki/cert/ca_chain")
		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	// Create a rate limit quota with a low RPS of 7.7, which means we can process
	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
	// by a refill rate of 7.7 per-second.
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  7.7,
		"burst": 8,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// The ideal maximum number of successful requests is the initial burst
	// plus the refill rate times the elapsed seconds: burst + (RPS * totalSeconds).
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %s", numSuccess, numFail, elapsed)
	}

	// ensure that we should never get more requests than allowed (+1 slack
	// for the truncation of ideal and timing jitter)
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %s", want, numSuccess, numFail, elapsed)
	}

	// update the rate limit quota with a high RPS such that no requests should fail
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  1000.0,
		"burst": 3000,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, numFail, _ = testRPS(reqFunc, 5*time.Second)
	if numFail > 0 {
		t.Fatalf("unexpected number of failed requests: %d", numFail)
	}
}
// TestQuotas_RateLimitQuota_MountPrecedence sets up two quotas:
//   - a global quota with no path: 14.7 rps, burst 15;
//   - a quota scoped to "pki/": 7.7 rps, burst 8.
//
// It then checks that requests to the mount are bounded by the stricter
// mount quota. If the root quota applied instead, the success count would
// exceed the mount quota's upper bound checked below.
func TestQuotas_RateLimitQuota_MountPrecedence(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client
	vault.TestWaitActive(t, core)

	// create PKI mount
	err := client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}

	// create a root rate limit quota
	_, err = client.Logical().Write("sys/quotas/rate-limit/root-rlq", map[string]interface{}{
		"name":  "root-rlq",
		"rate":  14.7,
		"burst": 15,
	})
	if err != nil {
		t.Fatal(err)
	}

	// create a mount rate limit quota with a lower RPS than the root rate limit quota
	_, err = client.Logical().Write("sys/quotas/rate-limit/mount-rlq", map[string]interface{}{
		"name":  "mount-rlq",
		"rate":  7.7,
		"burst": 8,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Each read of the mount counts as a success or a failure. The
	// assertions below rely on rate-limited requests surfacing as errors.
	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("pki/cert/ca_chain")
		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	// ensure mount rate limit quota takes precedence over root rate limit quota
	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// The ideal maximum number of successful requests under the mount quota
	// is its burst plus its rate times the elapsed seconds:
	// burst + (RPS * totalSeconds).
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %s", numSuccess, numFail, elapsed)
	}

	// ensure that we should never get more requests than allowed (+1 slack
	// for the truncation of ideal and timing jitter)
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %s", want, numSuccess, numFail, elapsed)
	}
}
// TestQuotas_RateLimitQuota checks a rate limit quota created without a path
// in two phases:
//   - with a low rate, repeatedly reading the quota's own configuration
//     produces some failures and never more successes than
//     burst + rate*elapsed allows;
//   - after the quota is raised to a high rate and burst, no request fails.
func TestQuotas_RateLimitQuota(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client
	vault.TestWaitActive(t, core)

	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
		Type: "userpass",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
		"password": "bar",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a rate limit quota with a low RPS of 7.7, which means we can process
	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
	// by a refill rate of 7.7 per-second.
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  7.7,
		"burst": 8,
	})
	if err != nil {
		t.Fatal(err)
	}

	// Each read counts as a success or a failure. The assertions below rely
	// on rate-limited requests surfacing as errors.
	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("sys/quotas/rate-limit/rlq")
		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// The ideal maximum number of successful requests is the initial burst
	// plus the refill rate times the elapsed seconds: burst + (RPS * totalSeconds).
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %s", numSuccess, numFail, elapsed)
	}

	// ensure that we should never get more requests than allowed (+1 slack
	// for the truncation of ideal and timing jitter)
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %s", want, numSuccess, numFail, elapsed)
	}

	// Allow time (1s) for the rate limit to refill before updating the
	// quota. The update below is itself a request, and it could otherwise
	// be rejected by the exhausted quota.
	time.Sleep(time.Second)

	// update the rate limit quota with a high RPS such that no requests should fail
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  1000.0,
		"burst": 3000,
	})
	if err != nil {
		t.Fatal(err)
	}

	_, numFail, _ = testRPS(reqFunc, 5*time.Second)
	if numFail > 0 {
		t.Fatalf("unexpected number of failed requests: %d", numFail)
	}
}