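// Package quotas contains tests that exercise Vault's rate limit quota
// enforcement against a running test cluster.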
package quotas

import (
	"fmt"
	"testing"
	"time"

	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/sdk/logical"

	"github.com/hashicorp/vault/builtin/credential/userpass"
	"github.com/hashicorp/vault/builtin/logical/pki"
	"github.com/hashicorp/vault/helper/testhelpers/teststorage"
	"github.com/hashicorp/vault/vault"
	"go.uber.org/atomic"
)

const (
	testLookupOnlyPolicy = `
path "/auth/token/lookup" {
	capabilities = [ "create", "update"]
}
`
)
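
// coreConfig registers the pki secrets engine and the userpass auth method so
// the test clusters below can mount them.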
var (
	coreConfig = &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"pki": pki.Factory,
		},
		CredentialBackends: map[string]logical.Factory{
			"userpass": userpass.Factory,
		},
	}
)
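
// setupMounts enables the userpass auth method with a test user and mounts a
// pki secrets engine with an internal root CA and a "test" role.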
func setupMounts(t *testing.T, client *api.Client) {
	t.Helper()

	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
		Type: "userpass",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
		"password": "bar",
	})
	if err != nil {
		t.Fatal(err)
	}

	err = client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}
}
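
// teardownMounts unmounts the pki secrets engine and disables the userpass
// auth method enabled by setupMounts.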
func teardownMounts(t *testing.T, client *api.Client) {
	t.Helper()
	if err := client.Sys().Unmount("pki"); err != nil {
		t.Fatal(err)
	}
	if err := client.Sys().DisableAuth("userpass"); err != nil {
		t.Fatal(err)
	}
}
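
// testRPS repeatedly invokes reqFunc for the duration d and returns the number
// of successful requests, the number of failed requests, and the elapsed time.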
func testRPS(reqFunc func(numSuccess, numFail *atomic.Int32), d time.Duration) (int32, int32, time.Duration) {
	numSuccess := atomic.NewInt32(0)
	numFail := atomic.NewInt32(0)

	start := time.Now()
	end := start.Add(d)
	for time.Now().Before(end) {
		reqFunc(numSuccess, numFail)
	}

	return numSuccess.Load(), numFail.Load(), time.Since(start)
}
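
// waitForRemovalOrTimeout polls the given path every tick until a read returns
// no data (the resource has been removed) or the timeout to elapses.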
func waitForRemovalOrTimeout(c *api.Client, path string, tick, to time.Duration) error {
	ticker := time.Tick(tick)
	timeout := time.After(to)

	// wait for the resource to be removed
	for {
		select {
		case <-timeout:
			return fmt.Errorf("timeout exceeded waiting for resource to be deleted: %s", path)

		case <-ticker:
			resp, err := c.Logical().Read(path)
			if err != nil {
				return err
			}

			if resp == nil {
				return nil
			}
		}
	}
}
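
// TestQuotas_RateLimitQuota_Mount verifies that a rate limit quota scoped to a
// single mount (pki/) throttles requests against that mount, and that raising
// the quota's rate and burst removes the throttling.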
func TestQuotas_RateLimitQuota_Mount(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client
	vault.TestWaitActive(t, core)

	err := client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}

	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("pki/cert/ca_chain")
		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	// Create a rate limit quota with a low RPS of 7.7, which means we can process
	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
	// by a refill rate of 7.7 per-second.
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  7.7,
		"burst": 8,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
	}

	// ensure that we never get more successful requests than allowed
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
	}

	// update the rate limit quota with a high RPS such that no requests should fail
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  1000.0,
		"burst": 3000,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, numFail, _ = testRPS(reqFunc, 5*time.Second)
	if numFail > 0 {
		t.Fatalf("unexpected number of failed requests: %d", numFail)
	}
}
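
// TestQuotas_RateLimitQuota_MountPrecedence verifies that a rate limit quota
// scoped to a mount takes precedence over a less restrictive global quota.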
func TestQuotas_RateLimitQuota_MountPrecedence(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client

	vault.TestWaitActive(t, core)

	// create PKI mount
	err := client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}

	// create a root rate limit quota
	_, err = client.Logical().Write("sys/quotas/rate-limit/root-rlq", map[string]interface{}{
		"name":  "root-rlq",
		"rate":  14.7,
		"burst": 15,
	})
	if err != nil {
		t.Fatal(err)
	}

	// create a mount rate limit quota with a lower RPS than the root rate limit quota
	_, err = client.Logical().Write("sys/quotas/rate-limit/mount-rlq", map[string]interface{}{
		"name":  "mount-rlq",
		"rate":  7.7,
		"burst": 8,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("pki/cert/ca_chain")
		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	// ensure the mount rate limit quota takes precedence over the root rate limit quota
	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
	}

	// ensure that we never get more successful requests than allowed
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
	}
}
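
// TestQuotas_RateLimitQuota verifies that a global (pathless) rate limit quota
// throttles requests and that updating it to a high rate and burst removes the
// throttling.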
func TestQuotas_RateLimitQuota(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client

	vault.TestWaitActive(t, core)

	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
		Type: "userpass",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
		"password": "bar",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a rate limit quota with a low RPS of 7.7, which means we can process
	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
	// by a refill rate of 7.7 per-second.
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  7.7,
		"burst": 8,
	})
	if err != nil {
		t.Fatal(err)
	}

	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("sys/quotas/rate-limit/rlq")
		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
	}

	// ensure that we never get more successful requests than allowed
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
	}

	// allow time (1s) for the rate limit to refill before updating the quota
	time.Sleep(time.Second)

	// update the rate limit quota with a high RPS such that no requests should fail
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  1000.0,
		"burst": 3000,
	})
	if err != nil {
		t.Fatal(err)
	}

	_, numFail, _ = testRPS(reqFunc, 5*time.Second)
	if numFail > 0 {
		t.Fatalf("unexpected number of failed requests: %d", numFail)
	}
}