Vault Agent Cache (#6220)
* vault-agent-cache: squashed 250+ commits * Add proper token revocation validations to the tests * Add more test cases * Avoid leaking by not closing request/response bodies; add comments * Fix revoke orphan use case; update tests * Add CLI test for making request over unix socket * agent/cache: remove namespace-related tests * Strip-off the auto-auth token from the lookup response * Output listener details along with configuration * Add scheme to API address output * leasecache: use IndexNameLease for prefix lease revocations * Make CLI accept the fully qualified unix address * export VAULT_AGENT_ADDR=unix://path/to/socket * unix:/ to unix://
This commit is contained in:
parent
9044173d6e
commit
feb235d5f8
|
@ -48,7 +48,9 @@ Vagrantfile
|
|||
# Configs
|
||||
*.hcl
|
||||
!command/agent/config/test-fixtures/config.hcl
|
||||
!command/agent/config/test-fixtures/config-cache.hcl
|
||||
!command/agent/config/test-fixtures/config-embedded-type.hcl
|
||||
!command/agent/config/test-fixtures/config-cache-embedded-type.hcl
|
||||
|
||||
.DS_Store
|
||||
.idea
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const EnvVaultAgentAddress = "VAULT_AGENT_ADDR"
|
||||
const EnvVaultAddress = "VAULT_ADDR"
|
||||
const EnvVaultCACert = "VAULT_CACERT"
|
||||
const EnvVaultCAPath = "VAULT_CAPATH"
|
||||
|
@ -237,6 +238,10 @@ func (c *Config) ReadEnvironment() error {
|
|||
if v := os.Getenv(EnvVaultAddress); v != "" {
|
||||
envAddress = v
|
||||
}
|
||||
// Agent's address will take precedence over Vault's address
|
||||
if v := os.Getenv(EnvVaultAgentAddress); v != "" {
|
||||
envAddress = v
|
||||
}
|
||||
if v := os.Getenv(EnvVaultMaxRetries); v != "" {
|
||||
maxRetries, err := strconv.ParseUint(v, 10, 32)
|
||||
if err != nil {
|
||||
|
@ -366,6 +371,21 @@ func NewClient(c *Config) (*Client, error) {
|
|||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
// If address begins with a `unix://`, treat it as a socket file path and set
|
||||
// the HttpClient's transport to the corresponding socket dialer.
|
||||
if strings.HasPrefix(c.Address, "unix://") {
|
||||
socketFilePath := strings.TrimPrefix(c.Address, "unix://")
|
||||
c.HttpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(context.Context, string, string) (net.Conn, error) {
|
||||
return net.Dial("unix", socketFilePath)
|
||||
},
|
||||
},
|
||||
}
|
||||
// Set the unix address for URL parsing below
|
||||
c.Address = "http://unix"
|
||||
}
|
||||
|
||||
u, err := url.Parse(c.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -707,7 +727,7 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
|
|||
|
||||
redirectCount := 0
|
||||
START:
|
||||
req, err := r.toRetryableHTTP()
|
||||
req, err := r.ToRetryableHTTP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ func (r *Request) ResetJSONBody() error {
|
|||
// DEPRECATED: ToHTTP turns this request into a valid *http.Request for use
|
||||
// with the net/http package.
|
||||
func (r *Request) ToHTTP() (*http.Request, error) {
|
||||
req, err := r.toRetryableHTTP()
|
||||
req, err := r.ToRetryableHTTP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ func (r *Request) ToHTTP() (*http.Request, error) {
|
|||
return req.Request, nil
|
||||
}
|
||||
|
||||
func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) {
|
||||
func (r *Request) ToRetryableHTTP() (*retryablehttp.Request, error) {
|
||||
// Encode the query parameters
|
||||
r.URL.RawQuery = r.Params.Encode()
|
||||
|
||||
|
|
|
@ -292,6 +292,7 @@ type SecretAuth struct {
|
|||
TokenPolicies []string `json:"token_policies"`
|
||||
IdentityPolicies []string `json:"identity_policies"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
Orphan bool `json:"orphan"`
|
||||
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Renewable bool `json:"renewable"`
|
||||
|
|
102
command/agent.go
102
command/agent.go
|
@ -4,6 +4,10 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -23,6 +27,7 @@ import (
|
|||
"github.com/hashicorp/vault/command/agent/auth/gcp"
|
||||
"github.com/hashicorp/vault/command/agent/auth/jwt"
|
||||
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
||||
"github.com/hashicorp/vault/command/agent/cache"
|
||||
"github.com/hashicorp/vault/command/agent/config"
|
||||
"github.com/hashicorp/vault/command/agent/sink"
|
||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||
|
@ -218,19 +223,6 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
info["cgo"] = "enabled"
|
||||
}
|
||||
|
||||
// Server configuration output
|
||||
padding := 24
|
||||
sort.Strings(infoKeys)
|
||||
c.UI.Output("==> Vault agent configuration:\n")
|
||||
for _, k := range infoKeys {
|
||||
c.UI.Output(fmt.Sprintf(
|
||||
"%s%s: %s",
|
||||
strings.Repeat(" ", padding-len(k)),
|
||||
strings.Title(k),
|
||||
info[k]))
|
||||
}
|
||||
c.UI.Output("")
|
||||
|
||||
// Tests might not want to start a vault server and just want to verify
|
||||
// the configuration.
|
||||
if c.flagTestVerifyOnly {
|
||||
|
@ -332,10 +324,92 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
|
||||
})
|
||||
|
||||
// Start things running
|
||||
// Start auto-auth and sink servers
|
||||
go ah.Run(ctx, method)
|
||||
go ss.Run(ctx, ah.OutputCh, sinks)
|
||||
|
||||
// Parse agent listener configurations
|
||||
if config.Cache != nil && len(config.Cache.Listeners) != 0 {
|
||||
cacheLogger := c.logger.Named("cache")
|
||||
|
||||
// Create the API proxier
|
||||
apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
|
||||
Logger: cacheLogger.Named("apiproxy"),
|
||||
})
|
||||
|
||||
// Create the lease cache proxier and set its underlying proxier to
|
||||
// the API proxier.
|
||||
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
||||
BaseContext: ctx,
|
||||
Proxier: apiProxy,
|
||||
Logger: cacheLogger.Named("leasecache"),
|
||||
})
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
|
||||
|
||||
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, config.Cache.UseAutoAuthToken, c.client))
|
||||
|
||||
var listeners []net.Listener
|
||||
for i, lnConfig := range config.Cache.Listeners {
|
||||
listener, props, _, err := cache.ServerListener(lnConfig, c.logWriter, c.UI)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error parsing listener configuration: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
listeners = append(listeners, listener)
|
||||
|
||||
scheme := "https://"
|
||||
if props["tls"] == "disabled" {
|
||||
scheme = "http://"
|
||||
}
|
||||
if lnConfig.Type == "unix" {
|
||||
scheme = "unix://"
|
||||
}
|
||||
|
||||
infoKey := fmt.Sprintf("api address %d", i+1)
|
||||
info[infoKey] = scheme + listener.Addr().String()
|
||||
infoKeys = append(infoKeys, infoKey)
|
||||
|
||||
cacheLogger.Info("starting listener", "addr", listener.Addr().String())
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: cacheLogger.StandardLogger(nil),
|
||||
}
|
||||
go server.Serve(listener)
|
||||
}
|
||||
|
||||
// Ensure that listeners are closed at all the exits
|
||||
listenerCloseFunc := func() {
|
||||
for _, ln := range listeners {
|
||||
ln.Close()
|
||||
}
|
||||
}
|
||||
defer c.cleanupGuard.Do(listenerCloseFunc)
|
||||
}
|
||||
|
||||
// Server configuration output
|
||||
padding := 24
|
||||
sort.Strings(infoKeys)
|
||||
c.UI.Output("==> Vault agent configuration:\n")
|
||||
for _, k := range infoKeys {
|
||||
c.UI.Output(fmt.Sprintf(
|
||||
"%s%s: %s",
|
||||
strings.Repeat(" ", padding-len(k)),
|
||||
strings.Title(k),
|
||||
info[k]))
|
||||
}
|
||||
c.UI.Output("")
|
||||
|
||||
// Release the log gate.
|
||||
c.logGate.Flush()
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
// APIProxy is an implementation of the proxier interface that is used to
|
||||
// forward the request to Vault and get the response.
|
||||
type APIProxy struct {
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
type APIProxyConfig struct {
|
||||
Logger hclog.Logger
|
||||
}
|
||||
|
||||
func NewAPIProxy(config *APIProxyConfig) Proxier {
|
||||
return &APIProxy{
|
||||
logger: config.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
client, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.SetToken(req.Token)
|
||||
client.SetHeaders(req.Request.Header)
|
||||
|
||||
fwReq := client.NewRequest(req.Request.Method, req.Request.URL.Path)
|
||||
fwReq.BodyBytes = req.RequestBody
|
||||
|
||||
// Make the request to Vault and get the response
|
||||
ap.logger.Info("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
resp, err := client.RawRequestWithContext(ctx, fwReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and reset response body
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
ap.logger.Error("failed to read request body", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
resp.Body = ioutil.NopCloser(bytes.NewBuffer(respBody))
|
||||
|
||||
return &SendResponse{
|
||||
Response: resp,
|
||||
ResponseBody: respBody,
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
)
|
||||
|
||||
func TestCache_APIProxy(t *testing.T) {
|
||||
cleanup, client, _, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
|
||||
defer cleanup()
|
||||
|
||||
proxier := NewAPIProxy(&APIProxyConfig{
|
||||
Logger: logging.NewVaultLogger(hclog.Trace),
|
||||
})
|
||||
|
||||
r := client.NewRequest("GET", "/v1/sys/health")
|
||||
req, err := r.ToRetryableHTTP()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := proxier.Send(namespace.RootContext(nil), &SendRequest{
|
||||
Request: req.Request,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result api.HealthResponse
|
||||
err = jsonutil.DecodeJSONFromReader(resp.Response.Body, &result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !result.Initialized || result.Sealed || result.Standby {
|
||||
t.Fatalf("bad sys/health response")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,926 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
kv "github.com/hashicorp/vault-plugin-secrets-kv"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
const policyAdmin = `
|
||||
path "*" {
|
||||
capabilities = ["sudo", "create", "read", "update", "delete", "list"]
|
||||
}
|
||||
`
|
||||
|
||||
// setupClusterAndAgent is a helper func used to set up a test cluster and
|
||||
// caching agent. It returns a cleanup func that should be deferred immediately
|
||||
// along with two clients, one for direct cluster communication and another to
|
||||
// talk to the caching agent.
|
||||
func setupClusterAndAgent(ctx context.Context, t *testing.T, coreConfig *vault.CoreConfig) (func(), *api.Client, *api.Client, *LeaseCache) {
|
||||
t.Helper()
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Handle sane defaults
|
||||
if coreConfig == nil {
|
||||
coreConfig = &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: logging.NewVaultLogger(hclog.Trace),
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"userpass": userpass.Factory,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if coreConfig.CredentialBackends == nil {
|
||||
coreConfig.CredentialBackends = map[string]logical.Factory{
|
||||
"userpass": userpass.Factory,
|
||||
}
|
||||
}
|
||||
|
||||
// Init new test cluster
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
cluster.Start()
|
||||
|
||||
cores := cluster.Cores
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
// clusterClient is the client that is used to talk directly to the cluster.
|
||||
clusterClient := cores[0].Client
|
||||
|
||||
// Add an admin policy
|
||||
if err := clusterClient.Sys().PutPolicy("admin", policyAdmin); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set up the userpass auth backend and an admin user. Used for getting a token
|
||||
// for the agent later down in this func.
|
||||
clusterClient.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
|
||||
Type: "userpass",
|
||||
})
|
||||
|
||||
_, err := clusterClient.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
|
||||
"password": "bar",
|
||||
"policies": []string{"admin"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set up env vars for agent consumption
|
||||
origEnvVaultAddress := os.Getenv(api.EnvVaultAddress)
|
||||
os.Setenv(api.EnvVaultAddress, clusterClient.Address())
|
||||
|
||||
origEnvVaultCACert := os.Getenv(api.EnvVaultCACert)
|
||||
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
|
||||
|
||||
cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create the API proxier
|
||||
apiProxy := NewAPIProxy(&APIProxyConfig{
|
||||
Logger: cacheLogger.Named("apiproxy"),
|
||||
})
|
||||
|
||||
// Create the lease cache proxier and set its underlying proxier to
|
||||
// the API proxier.
|
||||
leaseCache, err := NewLeaseCache(&LeaseCacheConfig{
|
||||
BaseContext: ctx,
|
||||
Proxier: apiProxy,
|
||||
Logger: cacheLogger.Named("leasecache"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
|
||||
|
||||
mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, false, clusterClient))
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: cacheLogger.StandardLogger(nil),
|
||||
}
|
||||
go server.Serve(listener)
|
||||
|
||||
// testClient is the client that is used to talk to the agent for proxying/caching behavior.
|
||||
testClient, err := clusterClient.Clone()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Login via userpass method to derive a managed token. Set that token as the
|
||||
// testClient's token
|
||||
resp, err := testClient.Logical().Write("auth/userpass/login/foo", map[string]interface{}{
|
||||
"password": "bar",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testClient.SetToken(resp.Auth.ClientToken)
|
||||
|
||||
cleanup := func() {
|
||||
cluster.Cleanup()
|
||||
os.Setenv(api.EnvVaultAddress, origEnvVaultAddress)
|
||||
os.Setenv(api.EnvVaultCACert, origEnvVaultCACert)
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
return cleanup, clusterClient, testClient, leaseCache
|
||||
}
|
||||
|
||||
func tokenRevocationValidation(t *testing.T, sampleSpace map[string]string, expected map[string]string, leaseCache *LeaseCache) {
|
||||
t.Helper()
|
||||
for val, valType := range sampleSpace {
|
||||
index, err := leaseCache.db.Get(valType, val)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if expected[val] == "" && index != nil {
|
||||
t.Fatalf("failed to evict index from the cache: type: %q, value: %q", valType, val)
|
||||
}
|
||||
if expected[val] != "" && index == nil {
|
||||
t.Fatalf("evicted an undesired index from cache: type: %q, value: %q", valType, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_RevokeOrphan(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke-orphan the intermediate token. This should result in its own
|
||||
// eviction and evictions of the revoked token's leases. All other things
|
||||
// including the child tokens and leases of the child tokens should be
|
||||
// untouched.
|
||||
testClient.SetToken(token2)
|
||||
err = testClient.Auth().Token().RevokeOrphan(token2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = map[string]string{
|
||||
token1: "token",
|
||||
lease1: "lease",
|
||||
token3: "token",
|
||||
lease3: "lease",
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_LeafLevelToken(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke the lef token. This should evict all the leases belonging to this
|
||||
// token, evict entries for all the child tokens and their respective
|
||||
// leases.
|
||||
testClient.SetToken(token3)
|
||||
err = testClient.Auth().Token().RevokeSelf("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = map[string]string{
|
||||
token1: "token",
|
||||
lease1: "lease",
|
||||
token2: "token",
|
||||
lease2: "lease",
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_IntermediateLevelToken(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke the second level token. This should evict all the leases
|
||||
// belonging to this token, evict entries for all the child tokens and
|
||||
// their respective leases.
|
||||
testClient.SetToken(token2)
|
||||
err = testClient.Auth().Token().RevokeSelf("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = map[string]string{
|
||||
token1: "token",
|
||||
lease1: "lease",
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_TopLevelToken(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke the top level token. This should evict all the leases belonging
|
||||
// to this token, evict entries for all the child tokens and their
|
||||
// respective leases.
|
||||
testClient.SetToken(token1)
|
||||
err = testClient.Auth().Token().RevokeSelf("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = make(map[string]string)
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_Shutdown(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
ctx, rootCancelFunc := context.WithCancel(namespace.RootContext(nil))
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(ctx, t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
rootCancelFunc()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Ensure that all the entries are now gone
|
||||
expected = make(map[string]string)
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_BaseContextCancellation(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Cancel the base context of the lease cache. This should trigger
|
||||
// evictions of all the entries from the cache.
|
||||
leaseCache.baseCtxInfo.CancelFunc()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Ensure that all the entries are now gone
|
||||
expected = make(map[string]string)
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_NonCacheable(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": kv.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
// Query mounts first
|
||||
origMounts, err := testClient.Sys().ListMounts()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Mount a kv backend
|
||||
if err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
Options: map[string]string{
|
||||
"version": "2",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Query mounts again
|
||||
newMounts, err := testClient.Sys().ListMounts()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(origMounts, newMounts); diff == nil {
|
||||
t.Logf("response #1: %#v", origMounts)
|
||||
t.Logf("response #2: %#v", newMounts)
|
||||
t.Fatal("expected requests to be not cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_AuthResponse(t *testing.T) {
|
||||
cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
|
||||
defer cleanup()
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := resp.Auth.ClientToken
|
||||
testClient.SetToken(token)
|
||||
|
||||
authTokeCreateReq := func(t *testing.T, policies map[string]interface{}) *api.Secret {
|
||||
resp, err := testClient.Logical().Write("auth/token/create", policies)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Auth == nil || resp.Auth.ClientToken == "" {
|
||||
t.Fatalf("expected a valid client token in the response, got = %#v", resp)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Test on auth response by creating a child token
|
||||
{
|
||||
proxiedResp := authTokeCreateReq(t, map[string]interface{}{
|
||||
"policies": "default",
|
||||
})
|
||||
|
||||
cachedResp := authTokeCreateReq(t, map[string]interface{}{
|
||||
"policies": "default",
|
||||
})
|
||||
|
||||
if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test on *non-renewable* auth response by creating a child root token
|
||||
{
|
||||
proxiedResp := authTokeCreateReq(t, nil)
|
||||
|
||||
cachedResp := authTokeCreateReq(t, nil)
|
||||
|
||||
if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseResponse(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
cleanup, client, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
err := client.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test proxy by issuing two different requests
|
||||
{
|
||||
// Write data to the lease-kv backend
|
||||
_, err := testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = testClient.Logical().Write("kv/foobar", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
firstResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secondResp, err := testClient.Logical().Read("kv/foobar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(firstResp, secondResp); diff == nil {
|
||||
t.Logf("response: %#v", firstResp)
|
||||
t.Fatal("expected proxied responses, got cached response on second request")
|
||||
}
|
||||
}
|
||||
|
||||
// Test caching behavior by issue the same request twice
|
||||
{
|
||||
_, err := testClient.Logical().Write("kv/baz", map[string]interface{}{
|
||||
"value": "foo",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
proxiedResp, err := testClient.Logical().Read("kv/baz")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cachedResp, err := testClient.Logical().Read("kv/baz")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(proxiedResp, cachedResp); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,265 @@
|
|||
package cachememdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
memdb "github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
tableNameIndexer = "indexer"
|
||||
)
|
||||
|
||||
// CacheMemDB is the underlying cache database for storing indexes.
|
||||
type CacheMemDB struct {
|
||||
db *memdb.MemDB
|
||||
}
|
||||
|
||||
// New creates a new instance of CacheMemDB.
|
||||
func New() (*CacheMemDB, error) {
|
||||
db, err := newDB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CacheMemDB{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newDB() (*memdb.MemDB, error) {
|
||||
cacheSchema := &memdb.DBSchema{
|
||||
Tables: map[string]*memdb.TableSchema{
|
||||
tableNameIndexer: &memdb.TableSchema{
|
||||
Name: tableNameIndexer,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
// This index enables fetching the cached item based on the
|
||||
// identifier of the index.
|
||||
IndexNameID: &memdb.IndexSchema{
|
||||
Name: IndexNameID,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// a given request path, in a given namespace.
|
||||
IndexNameRequestPath: &memdb.IndexSchema{
|
||||
Name: IndexNameRequestPath,
|
||||
Unique: false,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "Namespace",
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "RequestPath",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache
|
||||
// belonging to the leases of a given token.
|
||||
IndexNameLeaseToken: &memdb.IndexSchema{
|
||||
Name: IndexNameLeaseToken,
|
||||
Unique: false,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "LeaseToken",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache
|
||||
// that are tied to the given token, regardless of the
|
||||
// entries belonging to the token or belonging to the
|
||||
// lease.
|
||||
IndexNameToken: &memdb.IndexSchema{
|
||||
Name: IndexNameToken,
|
||||
Unique: true,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "Token",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// the given parent token.
|
||||
IndexNameTokenParent: &memdb.IndexSchema{
|
||||
Name: IndexNameTokenParent,
|
||||
Unique: false,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "TokenParent",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// the given accessor.
|
||||
IndexNameTokenAccessor: &memdb.IndexSchema{
|
||||
Name: IndexNameTokenAccessor,
|
||||
Unique: true,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "TokenAccessor",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// the given lease identifier.
|
||||
IndexNameLease: &memdb.IndexSchema{
|
||||
Name: IndexNameLease,
|
||||
Unique: true,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "Lease",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db, err := memdb.NewMemDB(cacheSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Get returns the index based on the indexer and the index values provided.
|
||||
func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, error) {
|
||||
if !validIndexName(indexName) {
|
||||
return nil, fmt.Errorf("invalid index name %q", indexName)
|
||||
}
|
||||
|
||||
raw, err := c.db.Txn(false).First(tableNameIndexer, indexName, indexValues...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
index, ok := raw.(*Index)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to parse index value from the cache")
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// Set stores the index into the cache.
|
||||
func (c *CacheMemDB) Set(index *Index) error {
|
||||
if index == nil {
|
||||
return errors.New("nil index provided")
|
||||
}
|
||||
|
||||
txn := c.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
if err := txn.Insert(tableNameIndexer, index); err != nil {
|
||||
return fmt.Errorf("unable to insert index into cache: %v", err)
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByPrefix returns all the cached indexes based on the index name and the
|
||||
// value prefix.
|
||||
func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) {
|
||||
if !validIndexName(indexName) {
|
||||
return nil, fmt.Errorf("invalid index name %q", indexName)
|
||||
}
|
||||
|
||||
indexName = indexName + "_prefix"
|
||||
|
||||
// Get all the objects
|
||||
iter, err := c.db.Txn(false).Get(tableNameIndexer, indexName, indexValues...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var indexes []*Index
|
||||
for {
|
||||
obj := iter.Next()
|
||||
if obj == nil {
|
||||
break
|
||||
}
|
||||
index, ok := obj.(*Index)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to cast cached index")
|
||||
}
|
||||
|
||||
indexes = append(indexes, index)
|
||||
}
|
||||
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
// Evict removes an index from the cache based on index name and value.
|
||||
func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error {
|
||||
index, err := c.Get(indexName, indexValues...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch index on cache deletion: %v", err)
|
||||
}
|
||||
|
||||
if index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
txn := c.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
if err := txn.Delete(tableNameIndexer, index); err != nil {
|
||||
return fmt.Errorf("unable to delete index from cache: %v", err)
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EvictAll removes all matching indexes from the cache based on index name and value.
|
||||
func (c *CacheMemDB) EvictAll(indexName, indexValue string) error {
|
||||
return c.batchEvict(false, indexName, indexValue)
|
||||
}
|
||||
|
||||
// EvictByPrefix removes all matching prefix indexes from the cache based on index name and prefix.
|
||||
func (c *CacheMemDB) EvictByPrefix(indexName, indexPrefix string) error {
|
||||
return c.batchEvict(true, indexName, indexPrefix)
|
||||
}
|
||||
|
||||
// batchEvict is a helper that supports eviction based on absolute and prefixed index values.
|
||||
func (c *CacheMemDB) batchEvict(isPrefix bool, indexName string, indexValues ...interface{}) error {
|
||||
if !validIndexName(indexName) {
|
||||
return fmt.Errorf("invalid index name %q", indexName)
|
||||
}
|
||||
|
||||
if isPrefix {
|
||||
indexName = indexName + "_prefix"
|
||||
}
|
||||
|
||||
txn := c.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
_, err := txn.DeleteAll(tableNameIndexer, indexName, indexValues...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush resets the underlying cache object.
|
||||
func (c *CacheMemDB) Flush() error {
|
||||
newDB, err := newDB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.db = newDB
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,388 @@
|
|||
package cachememdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
)
|
||||
|
||||
func testContextInfo() *ContextInfo {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
return &ContextInfo{
|
||||
Ctx: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
_, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Get(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test invalid index name
|
||||
_, err = cache.Get("foo", "bar")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
// Test on empty cache
|
||||
index, err := cache.Get(IndexNameID, "foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if index != nil {
|
||||
t.Fatalf("expected nil index, got: %v", index)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in := &Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_accessor",
|
||||
Lease: "test_lease",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
indexName string
|
||||
indexValues []interface{}
|
||||
}{
|
||||
{
|
||||
"by_index_id",
|
||||
"id",
|
||||
[]interface{}{in.ID},
|
||||
},
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
[]interface{}{in.Namespace, in.RequestPath},
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
[]interface{}{in.Lease},
|
||||
},
|
||||
{
|
||||
"by_token",
|
||||
"token",
|
||||
[]interface{}{in.Token},
|
||||
},
|
||||
{
|
||||
"by_token_accessor",
|
||||
"token_accessor",
|
||||
[]interface{}{in.TokenAccessor},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
out, err := cache.Get(tc.indexName, tc.indexValues...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(in, out); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_GetByPrefix(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test invalid index name
|
||||
_, err = cache.GetByPrefix("foo", "bar", "baz")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
// Test on empty cache
|
||||
index, err := cache.GetByPrefix(IndexNameRequestPath, "foo", "bar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if index != nil {
|
||||
t.Fatalf("expected nil index, got: %v", index)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in := &Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path/1",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_accessor",
|
||||
Lease: "path/to/test_lease/1",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in2 := &Index{
|
||||
ID: "test_id_2",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path/2",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_accessor",
|
||||
Lease: "path/to/test_lease/2",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
indexName string
|
||||
indexValues []interface{}
|
||||
}{
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
[]interface{}{"test_ns/", "/v1/request/path"},
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
[]interface{}{"path/to/test_lease"},
|
||||
},
|
||||
{
|
||||
"by_token",
|
||||
"token",
|
||||
[]interface{}{"test_token"},
|
||||
},
|
||||
{
|
||||
"by_token_accessor",
|
||||
"token_accessor",
|
||||
[]interface{}{"test_accessor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
out, err := cache.GetByPrefix(tc.indexName, tc.indexValues...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal([]*Index{in, in2}, out); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Set(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
index *Index
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty_fields",
|
||||
&Index{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"missing_required_fields",
|
||||
&Index{
|
||||
Lease: "foo",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"all_fields",
|
||||
&Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_accessor",
|
||||
Lease: "test_lease",
|
||||
RenewCtxInfo: testContextInfo(),
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if err := cache.Set(tc.index); (err != nil) != tc.wantErr {
|
||||
t.Fatalf("CacheMemDB.Set() error = %v, wantErr = %v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Evict(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test on empty cache
|
||||
if err := cache.Evict(IndexNameID, "foo"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testIndex := &Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_token_accessor",
|
||||
Lease: "test_lease",
|
||||
RenewCtxInfo: testContextInfo(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
indexName string
|
||||
indexValues []interface{}
|
||||
insertIndex *Index
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"empty_params",
|
||||
"",
|
||||
[]interface{}{""},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid_params",
|
||||
"foo",
|
||||
[]interface{}{"bar"},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"by_id",
|
||||
"id",
|
||||
[]interface{}{"test_id"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
[]interface{}{"test_ns/", "/v1/request/path"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_token",
|
||||
"token",
|
||||
[]interface{}{"test_token"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_token_accessor",
|
||||
"token_accessor",
|
||||
[]interface{}{"test_accessor"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
[]interface{}{"test_lease"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.insertIndex != nil {
|
||||
if err := cache.Set(tc.insertIndex); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := cache.Evict(tc.indexName, tc.indexValues...); (err != nil) != tc.wantErr {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify that the cache doesn't contain the entry any more
|
||||
index, err := cache.Get(tc.indexName, tc.indexValues...)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if index != nil {
|
||||
t.Fatalf("expected nil entry, got = %#v", index)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Flush(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in := &Index{
|
||||
ID: "test_id",
|
||||
Token: "test_token",
|
||||
Lease: "test_lease",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reset the cache
|
||||
if err := cache.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check the cache doesn't contain inserted index
|
||||
out, err := cache.Get(IndexNameID, "test_id")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected cache to be empty, got = %v", out)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package cachememdb
|
||||
|
||||
import "context"
|
||||
|
||||
type ContextInfo struct {
|
||||
Ctx context.Context
|
||||
CancelFunc context.CancelFunc
|
||||
DoneCh chan struct{}
|
||||
}
|
||||
|
||||
// Index holds the response to be cached along with multiple other values that
|
||||
// serve as pointers to refer back to this index.
|
||||
type Index struct {
|
||||
// ID is a value that uniquely represents the request held by this
|
||||
// index. This is computed by serializing and hashing the response object.
|
||||
// Required: true, Unique: true
|
||||
ID string
|
||||
|
||||
// Token is the token that fetched the response held by this index
|
||||
// Required: true, Unique: true
|
||||
Token string
|
||||
|
||||
// TokenParent is the parent token of the token held by this index
|
||||
// Required: false, Unique: false
|
||||
TokenParent string
|
||||
|
||||
// TokenAccessor is the accessor of the token being cached in this index
|
||||
// Required: true, Unique: true
|
||||
TokenAccessor string
|
||||
|
||||
// Namespace is the namespace that was provided in the request path as the
|
||||
// Vault namespace to query
|
||||
Namespace string
|
||||
|
||||
// RequestPath is the path of the request that resulted in the response
|
||||
// held by this index.
|
||||
// Required: true, Unique: false
|
||||
RequestPath string
|
||||
|
||||
// Lease is the identifier of the lease in Vault, that belongs to the
|
||||
// response held by this index.
|
||||
// Required: false, Unique: true
|
||||
Lease string
|
||||
|
||||
// LeaseToken is the identifier of the token that created the lease held by
|
||||
// this index.
|
||||
// Required: false, Unique: false
|
||||
LeaseToken string
|
||||
|
||||
// Response is the serialized response object that the agent is caching.
|
||||
Response []byte
|
||||
|
||||
// RenewCtxInfo holds the context and the corresponding cancel func for the
|
||||
// goroutine that manages the renewal of the secret belonging to the
|
||||
// response in this index.
|
||||
RenewCtxInfo *ContextInfo
|
||||
}
|
||||
|
||||
type IndexName uint32
|
||||
|
||||
const (
|
||||
// IndexNameID is the ID of the index constructed from the serialized request.
|
||||
IndexNameID = "id"
|
||||
|
||||
// IndexNameLease is the lease of the index.
|
||||
IndexNameLease = "lease"
|
||||
|
||||
// IndexNameRequestPath is the request path of the index.
|
||||
IndexNameRequestPath = "request_path"
|
||||
|
||||
// IndexNameToken is the token of the index.
|
||||
IndexNameToken = "token"
|
||||
|
||||
// IndexNameTokenAccessor is the token accessor of the index.
|
||||
IndexNameTokenAccessor = "token_accessor"
|
||||
|
||||
// IndexNameTokenParent is the token parent of the index.
|
||||
IndexNameTokenParent = "token_parent"
|
||||
|
||||
// IndexNameLeaseToken is the token that created the lease.
|
||||
IndexNameLeaseToken = "lease_token"
|
||||
)
|
||||
|
||||
func validIndexName(indexName string) bool {
|
||||
switch indexName {
|
||||
case "id":
|
||||
case "lease":
|
||||
case "request_path":
|
||||
case "token":
|
||||
case "token_accessor":
|
||||
case "token_parent":
|
||||
case "lease_token":
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, useAutoAuthToken bool, client *api.Client) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info("received request", "path", r.URL.Path, "method", r.Method)
|
||||
|
||||
token := r.Header.Get(consts.AuthHeaderName)
|
||||
if token == "" && useAutoAuthToken {
|
||||
logger.Debug("using auto auth token")
|
||||
token = client.Token()
|
||||
}
|
||||
|
||||
// Parse and reset body.
|
||||
reqBody, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
logger.Error("failed to read request body")
|
||||
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
|
||||
}
|
||||
if r.Body != nil {
|
||||
r.Body.Close()
|
||||
}
|
||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody))
|
||||
req := &SendRequest{
|
||||
Token: token,
|
||||
Request: r,
|
||||
RequestBody: reqBody,
|
||||
}
|
||||
|
||||
resp, err := proxier.Send(ctx, req)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to get the response: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = processTokenLookupResponse(ctx, logger, useAutoAuthToken, client, req, resp)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to process token lookup response: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Response.Body.Close()
|
||||
|
||||
copyHeader(w.Header(), resp.Response.Header)
|
||||
w.WriteHeader(resp.Response.StatusCode)
|
||||
io.Copy(w, resp.Response.Body)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// processTokenLookupResponse checks if the request was one of token
|
||||
// lookup-self. If the auto-auth token was used to perform lookup-self, the
|
||||
// identifier of the token and its accessor same will be stripped off of the
|
||||
// response.
|
||||
func processTokenLookupResponse(ctx context.Context, logger hclog.Logger, useAutoAuthToken bool, client *api.Client, req *SendRequest, resp *SendResponse) error {
|
||||
// If auto-auth token is not being used, there is nothing to do.
|
||||
if !useAutoAuthToken {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If lookup responded with non 200 status, there is nothing to do.
|
||||
if resp.Response.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Strip-off namespace related information from the request and get the
|
||||
// relative path of the request.
|
||||
_, path := deriveNamespaceAndRevocationPath(req)
|
||||
if path == vaultPathTokenLookupSelf {
|
||||
logger.Info("stripping auto-auth token from the response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse token lookup response: %v", err)
|
||||
}
|
||||
if secret != nil && secret.Data != nil && secret.Data["id"] != nil {
|
||||
token, ok := secret.Data["id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to type assert the token id in the response")
|
||||
}
|
||||
if token == client.Token() {
|
||||
delete(secret.Data, "id")
|
||||
delete(secret.Data, "accessor")
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Response.Body != nil {
|
||||
resp.Response.Body.Close()
|
||||
}
|
||||
resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
resp.Response.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
// Serialize and re-read the reponse
|
||||
var respBytes bytes.Buffer
|
||||
err = resp.Response.Write(&respBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize the updated response: %v", err)
|
||||
}
|
||||
|
||||
updatedResponse, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBytes.Bytes())), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to deserialize the updated response: %v", err)
|
||||
}
|
||||
|
||||
resp.Response = &api.Response{
|
||||
Response: updatedResponse,
|
||||
}
|
||||
resp.ResponseBody = bodyBytes
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respondError(w http.ResponseWriter, status int, err error) {
|
||||
logical.AdjustErrorStatusCode(&status, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
resp := &vaulthttp.ErrorResponse{Errors: make([]string, 0, 1)}
|
||||
if err != nil {
|
||||
resp.Errors = append(resp.Errors, err.Error())
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.Encode(resp)
|
||||
}
|
|
@ -0,0 +1,813 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
nshelper "github.com/hashicorp/vault/helper/namespace"
|
||||
)
|
||||
|
||||
const (
|
||||
vaultPathTokenCreate = "/v1/auth/token/create"
|
||||
vaultPathTokenRevoke = "/v1/auth/token/revoke"
|
||||
vaultPathTokenRevokeSelf = "/v1/auth/token/revoke-self"
|
||||
vaultPathTokenRevokeAccessor = "/v1/auth/token/revoke-accessor"
|
||||
vaultPathTokenRevokeOrphan = "/v1/auth/token/revoke-orphan"
|
||||
vaultPathTokenLookupSelf = "/v1/auth/token/lookup-self"
|
||||
vaultPathLeaseRevoke = "/v1/sys/leases/revoke"
|
||||
vaultPathLeaseRevokeForce = "/v1/sys/leases/revoke-force"
|
||||
vaultPathLeaseRevokePrefix = "/v1/sys/leases/revoke-prefix"
|
||||
)
|
||||
|
||||
var (
|
||||
contextIndexID = contextIndex{}
|
||||
errInvalidType = errors.New("invalid type provided")
|
||||
revocationPaths = []string{
|
||||
strings.TrimPrefix(vaultPathTokenRevoke, "/v1"),
|
||||
strings.TrimPrefix(vaultPathTokenRevokeSelf, "/v1"),
|
||||
strings.TrimPrefix(vaultPathTokenRevokeAccessor, "/v1"),
|
||||
strings.TrimPrefix(vaultPathTokenRevokeOrphan, "/v1"),
|
||||
strings.TrimPrefix(vaultPathLeaseRevoke, "/v1"),
|
||||
strings.TrimPrefix(vaultPathLeaseRevokeForce, "/v1"),
|
||||
strings.TrimPrefix(vaultPathLeaseRevokePrefix, "/v1"),
|
||||
}
|
||||
)
|
||||
|
||||
type contextIndex struct{}
|
||||
|
||||
type cacheClearRequest struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
Namespace string `json:"namespace"`
|
||||
}
|
||||
|
||||
// LeaseCache is an implementation of Proxier that handles
|
||||
// the caching of responses. It passes the incoming request
|
||||
// to an underlying Proxier implementation.
|
||||
type LeaseCache struct {
|
||||
proxier Proxier
|
||||
logger hclog.Logger
|
||||
db *cachememdb.CacheMemDB
|
||||
baseCtxInfo *ContextInfo
|
||||
}
|
||||
|
||||
// LeaseCacheConfig is the configuration for initializing a new
|
||||
// Lease.
|
||||
type LeaseCacheConfig struct {
|
||||
BaseContext context.Context
|
||||
Proxier Proxier
|
||||
Logger hclog.Logger
|
||||
}
|
||||
|
||||
// ContextInfo holds a derived context and cancelFunc pair.
|
||||
type ContextInfo struct {
|
||||
Ctx context.Context
|
||||
CancelFunc context.CancelFunc
|
||||
DoneCh chan struct{}
|
||||
}
|
||||
|
||||
// NewLeaseCache creates a new instance of a LeaseCache.
|
||||
func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("nil configuration provided")
|
||||
}
|
||||
|
||||
if conf.Proxier == nil || conf.Logger == nil {
|
||||
return nil, fmt.Errorf("missing configuration required params: %v", conf)
|
||||
}
|
||||
|
||||
db, err := cachememdb.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a base context for the lease cache layer
|
||||
baseCtx, baseCancelFunc := context.WithCancel(conf.BaseContext)
|
||||
baseCtxInfo := &ContextInfo{
|
||||
Ctx: baseCtx,
|
||||
CancelFunc: baseCancelFunc,
|
||||
}
|
||||
|
||||
return &LeaseCache{
|
||||
proxier: conf.Proxier,
|
||||
logger: conf.Logger,
|
||||
db: db,
|
||||
baseCtxInfo: baseCtxInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Send performs a cache lookup on the incoming request. If it's a cache hit,
|
||||
// it will return the cached response, otherwise it will delegate to the
|
||||
// underlying Proxier and cache the received response.
|
||||
func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
// Compute the index ID
|
||||
id, err := computeIndexID(req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to compute cache key", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the response for this request is already in the cache
|
||||
index, err := c.db.Get(cachememdb.IndexNameID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cached request is found, deserialize the response and return early
|
||||
if index != nil {
|
||||
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
|
||||
|
||||
reader := bufio.NewReader(bytes.NewReader(index.Response))
|
||||
resp, err := http.ReadResponse(reader, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to deserialize response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: resp,
|
||||
},
|
||||
ResponseBody: index.Response,
|
||||
}, nil
|
||||
}
|
||||
|
||||
c.logger.Debug("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
|
||||
// Pass the request down and get a response
|
||||
resp, err := c.proxier.Send(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the namespace from the request header
|
||||
namespace := req.Request.Header.Get(consts.NamespaceHeaderName)
|
||||
// We need to populate an empty value since go-memdb will skip over indexes
|
||||
// that contain empty values.
|
||||
if namespace == "" {
|
||||
namespace = "root/"
|
||||
}
|
||||
|
||||
// Build the index to cache based on the response received
|
||||
index = &cachememdb.Index{
|
||||
ID: id,
|
||||
Namespace: namespace,
|
||||
RequestPath: req.Request.URL.Path,
|
||||
}
|
||||
|
||||
secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
|
||||
if err != nil {
|
||||
c.logger.Error("failed to parse response as secret", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
isRevocation, err := c.handleRevocationRequest(ctx, req, resp)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to process the response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If this is a revocation request, do not go through cache logic.
|
||||
if isRevocation {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Fast path for responses with no secrets
|
||||
if secret == nil {
|
||||
c.logger.Debug("pass-through response; no secret in response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Short-circuit if the secret is not renewable
|
||||
tokenRenewable, err := secret.TokenIsRenewable()
|
||||
if err != nil {
|
||||
c.logger.Error("failed to parse renewable param", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
if !secret.Renewable && !tokenRenewable {
|
||||
c.logger.Debug("pass-through response; secret not renewable", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
var renewCtxInfo *ContextInfo
|
||||
switch {
|
||||
case secret.LeaseID != "":
|
||||
c.logger.Debug("processing lease response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If the lease belongs to a token that is not managed by the agent,
|
||||
// return the response without caching it.
|
||||
if entry == nil {
|
||||
c.logger.Debug("pass-through lease response; token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Derive a context for renewal using the token's context
|
||||
newCtxInfo := new(ContextInfo)
|
||||
newCtxInfo.Ctx, newCtxInfo.CancelFunc = context.WithCancel(entry.RenewCtxInfo.Ctx)
|
||||
newCtxInfo.DoneCh = make(chan struct{})
|
||||
renewCtxInfo = newCtxInfo
|
||||
|
||||
index.Lease = secret.LeaseID
|
||||
index.LeaseToken = req.Token
|
||||
|
||||
case secret.Auth != nil:
|
||||
c.logger.Debug("processing auth response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
isNonOrphanNewToken := strings.HasPrefix(req.Request.URL.Path, vaultPathTokenCreate) && resp.Response.StatusCode == http.StatusOK && !secret.Auth.Orphan
|
||||
|
||||
// If the new token is a result of token creation endpoints (not from
|
||||
// login endpoints), and if its a non-orphan, then the new token's
|
||||
// context should be derived from the context of the parent token.
|
||||
var parentCtx context.Context
|
||||
if isNonOrphanNewToken {
|
||||
entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If parent token is not managed by the agent, child shouldn't be
|
||||
// either.
|
||||
if entry == nil {
|
||||
c.logger.Debug("pass-through auth response; parent token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
c.logger.Debug("setting parent context", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
parentCtx = entry.RenewCtxInfo.Ctx
|
||||
|
||||
entry.TokenParent = req.Token
|
||||
}
|
||||
|
||||
renewCtxInfo = c.createCtxInfo(parentCtx, secret.Auth.ClientToken)
|
||||
index.Token = secret.Auth.ClientToken
|
||||
index.TokenAccessor = secret.Auth.Accessor
|
||||
|
||||
default:
|
||||
// We shouldn't be hitting this, but will err on the side of caution and
|
||||
// simply proxy.
|
||||
c.logger.Debug("pass-through response; secret without lease and token", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Serialize the response to store it in the cached index
|
||||
var respBytes bytes.Buffer
|
||||
err = resp.Response.Write(&respBytes)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to serialize response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the response body for upper layers to read
|
||||
if resp.Response.Body != nil {
|
||||
resp.Response.Body.Close()
|
||||
}
|
||||
resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(resp.ResponseBody))
|
||||
|
||||
// Set the index's Response
|
||||
index.Response = respBytes.Bytes()
|
||||
|
||||
// Store the index ID in the renewer context
|
||||
renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID)
|
||||
|
||||
// Store the renewer context in the index
|
||||
index.RenewCtxInfo = &cachememdb.ContextInfo{
|
||||
Ctx: renewCtx,
|
||||
CancelFunc: renewCtxInfo.CancelFunc,
|
||||
DoneCh: renewCtxInfo.DoneCh,
|
||||
}
|
||||
|
||||
// Store the index in the cache
|
||||
c.logger.Debug("storing response into the cache", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
err = c.db.Set(index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to cache the proxied response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start renewing the secret in the response
|
||||
go c.startRenewing(renewCtx, index, req, secret)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *LeaseCache) createCtxInfo(ctx context.Context, token string) *ContextInfo {
|
||||
if ctx == nil {
|
||||
ctx = c.baseCtxInfo.Ctx
|
||||
}
|
||||
ctxInfo := new(ContextInfo)
|
||||
ctxInfo.Ctx, ctxInfo.CancelFunc = context.WithCancel(ctx)
|
||||
ctxInfo.DoneCh = make(chan struct{})
|
||||
return ctxInfo
|
||||
}
|
||||
|
||||
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
|
||||
defer func() {
|
||||
id := ctx.Value(contextIndexID).(string)
|
||||
c.logger.Debug("evicting index from cache", "id", id, "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
err := c.db.Evict(cachememdb.IndexNameID, id)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to evict index", "id", id, "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create API client in the renewer", "error", err)
|
||||
return
|
||||
}
|
||||
client.SetToken(req.Token)
|
||||
client.SetHeaders(req.Request.Header)
|
||||
|
||||
renewer, err := client.NewRenewer(&api.RenewerInput{
|
||||
Secret: secret,
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create secret renewer", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("initiating renewal", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
go renewer.Renew()
|
||||
defer renewer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This is the case which captures context cancellations from token
|
||||
// and leases. Since all the contexts are derived from the agent's
|
||||
// context, this will also cover the shutdown scenario.
|
||||
c.logger.Debug("context cancelled; stopping renewer", "path", req.Request.URL.Path)
|
||||
return
|
||||
case err := <-renewer.DoneCh():
|
||||
// This case covers renewal completion and renewal errors
|
||||
if err != nil {
|
||||
c.logger.Error("failed to renew secret", "error", err)
|
||||
return
|
||||
}
|
||||
c.logger.Debug("renewal halted; evicting from cache", "path", req.Request.URL.Path)
|
||||
return
|
||||
case renewal := <-renewer.RenewCh():
|
||||
// This case captures secret renewals. Renewed secret is updated in
|
||||
// the cached index.
|
||||
c.logger.Debug("renewal received; updating cache", "path", req.Request.URL.Path)
|
||||
err = c.updateResponse(ctx, renewal)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to handle renewal", "error", err)
|
||||
return
|
||||
}
|
||||
case <-index.RenewCtxInfo.DoneCh:
|
||||
// This case indicates the renewal process to shutdown and evict
|
||||
// the cache entry. This is triggered when a specific secret
|
||||
// renewal needs to be killed without affecting any of the derived
|
||||
// context renewals.
|
||||
c.logger.Debug("done channel closed")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LeaseCache) updateResponse(ctx context.Context, renewal *api.RenewOutput) error {
|
||||
id := ctx.Value(contextIndexID).(string)
|
||||
|
||||
// Get the cached index using the id in the context
|
||||
index, err := c.db.Get(cachememdb.IndexNameID, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if index == nil {
|
||||
return fmt.Errorf("missing cache entry for id: %q", id)
|
||||
}
|
||||
|
||||
// Read the response from the index
|
||||
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(index.Response)), nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to deserialize response", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the body in the reponse by the renewed secret
|
||||
bodyBytes, err := json.Marshal(renewal.Secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
resp.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
resp.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
// Serialize the response
|
||||
var respBytes bytes.Buffer
|
||||
err = resp.Write(&respBytes)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to serialize updated response", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the response in the index and set it in the cache
|
||||
index.Response = respBytes.Bytes()
|
||||
err = c.db.Set(index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to cache the proxied response", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// computeIndexID results in a value that uniquely identifies a request
|
||||
// received by the agent. It does so by SHA256 hashing the serialized request
|
||||
// object containing the request path, query parameters and body parameters.
|
||||
func computeIndexID(req *SendRequest) (string, error) {
|
||||
var b bytes.Buffer
|
||||
|
||||
// Serialze the request
|
||||
if err := req.Request.Write(&b); err != nil {
|
||||
return "", fmt.Errorf("failed to serialize request: %v", err)
|
||||
}
|
||||
|
||||
// Reset the request body after it has been closed by Write
|
||||
if req.Request.Body != nil {
|
||||
req.Request.Body.Close()
|
||||
}
|
||||
req.Request.Body = ioutil.NopCloser(bytes.NewBuffer(req.RequestBody))
|
||||
|
||||
// Append req.Token into the byte slice. This is needed since auto-auth'ed
|
||||
// requests sets the token directly into SendRequest.Token
|
||||
b.Write([]byte(req.Token))
|
||||
|
||||
sum := sha256.Sum256(b.Bytes())
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
// HandleCacheClear returns a handlerFunc that can perform cache clearing operations.
|
||||
func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
req := new(cacheClearRequest)
|
||||
if err := jsonutil.DecodeJSONFromReader(r.Body, req); err != nil {
|
||||
if err == io.EOF {
|
||||
err = errors.New("empty JSON provided")
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, errwrap.Wrapf("failed to parse JSON input: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("received cache-clear request", "type", req.Type, "namespace", req.Namespace, "value", req.Value)
|
||||
|
||||
if err := c.handleCacheClear(ctx, req.Type, req.Namespace, req.Value); err != nil {
|
||||
// Default to 500 on error, unless the user provided an invalid type,
|
||||
// which would then be a 400.
|
||||
httpStatus := http.StatusInternalServerError
|
||||
if err == errInvalidType {
|
||||
httpStatus = http.StatusBadRequest
|
||||
}
|
||||
respondError(w, httpStatus, errwrap.Wrapf("failed to clear cache: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
func (c *LeaseCache) handleCacheClear(ctx context.Context, clearType string, clearValues ...interface{}) error {
|
||||
if len(clearValues) == 0 {
|
||||
return errors.New("no value(s) provided to clear corresponding cache entries")
|
||||
}
|
||||
|
||||
// The value that we want to clear, for most cases, is the last one provided.
|
||||
clearValue, ok := clearValues[len(clearValues)-1].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unable to convert %v to type string", clearValue)
|
||||
}
|
||||
|
||||
switch clearType {
|
||||
case "request_path":
|
||||
// For this particular case, we need to ensure that there are 2 provided
|
||||
// indexers for the proper lookup.
|
||||
if len(clearValues) != 2 {
|
||||
return fmt.Errorf("clearing cache by request path requires 2 indexers, got %d", len(clearValues))
|
||||
}
|
||||
|
||||
// The first value provided for this case will be the namespace, but if it's
|
||||
// an empty value we need to overwrite it with "root/" to ensure proper
|
||||
// cache lookup.
|
||||
if clearValues[0].(string) == "" {
|
||||
clearValues[0] = "root/"
|
||||
}
|
||||
|
||||
// Find all the cached entries which has the given request path and
|
||||
// cancel the contexts of all the respective renewers
|
||||
indexes, err := c.db.GetByPrefix(clearType, clearValues...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, index := range indexes {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
|
||||
case "token":
|
||||
if clearValue == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the context for the given token and cancel its context
|
||||
index, err := c.db.Get(cachememdb.IndexNameToken, clearValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Debug("cancelling context of index attached to token")
|
||||
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
|
||||
case "token_accessor", "lease":
|
||||
// Get the cached index and cancel the corresponding renewer context
|
||||
index, err := c.db.Get(clearType, clearValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Debug("cancelling context of index attached to accessor")
|
||||
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
|
||||
case "all":
|
||||
// Cancel the base context which triggers all the goroutines to
|
||||
// stop and evict entries from cache.
|
||||
c.logger.Debug("cancelling base context")
|
||||
c.baseCtxInfo.CancelFunc()
|
||||
|
||||
// Reset the base context
|
||||
baseCtx, baseCancel := context.WithCancel(ctx)
|
||||
c.baseCtxInfo = &ContextInfo{
|
||||
Ctx: baseCtx,
|
||||
CancelFunc: baseCancel,
|
||||
}
|
||||
|
||||
// Reset the memdb instance
|
||||
if err := c.db.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return errInvalidType
|
||||
}
|
||||
|
||||
c.logger.Debug("successfully cleared matching cache entries")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRevocationRequest checks whether the originating request is a
|
||||
// revocation request, and if so perform applicable cache cleanups.
|
||||
// Returns true is this is a revocation request.
|
||||
func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendRequest, resp *SendResponse) (bool, error) {
|
||||
// Lease and token revocations return 204's on success. Fast-path if that's
|
||||
// not the case.
|
||||
if resp.Response.StatusCode != http.StatusNoContent {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
_, path := deriveNamespaceAndRevocationPath(req)
|
||||
|
||||
switch {
|
||||
case path == vaultPathTokenRevoke:
|
||||
// Get the token from the request body
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
tokenRaw, ok := jsonBody["token"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get token from request body")
|
||||
}
|
||||
token, ok := tokenRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected token in the request body to be string")
|
||||
}
|
||||
|
||||
// Clear the cache entry associated with the token and all the other
|
||||
// entries belonging to the leases derived from this token.
|
||||
if err := c.handleCacheClear(ctx, "token", token); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case path == vaultPathTokenRevokeSelf:
|
||||
// Clear the cache entry associated with the token and all the other
|
||||
// entries belonging to the leases derived from this token.
|
||||
if err := c.handleCacheClear(ctx, "token", req.Token); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case path == vaultPathTokenRevokeAccessor:
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
accessorRaw, ok := jsonBody["accessor"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get accessor from request body")
|
||||
}
|
||||
accessor, ok := accessorRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected accessor in the request body to be string")
|
||||
}
|
||||
|
||||
if err := c.handleCacheClear(ctx, "token_accessor", accessor); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case path == vaultPathTokenRevokeOrphan:
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
tokenRaw, ok := jsonBody["token"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get token from request body")
|
||||
}
|
||||
token, ok := tokenRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected token in the request body to be string")
|
||||
}
|
||||
|
||||
// Kill the renewers of all the leases attached to the revoked token
|
||||
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLeaseToken, token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, index := range indexes {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
|
||||
// Kill the renewer of the revoked token
|
||||
index, err := c.db.Get(cachememdb.IndexNameToken, token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if index == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Indicate the renewer goroutine for this index to return. This will
|
||||
// not affect the child tokens because the context is not getting
|
||||
// cancelled.
|
||||
close(index.RenewCtxInfo.DoneCh)
|
||||
|
||||
// Clear the parent references of the revoked token in the entries
|
||||
// belonging to the child tokens of the revoked token.
|
||||
indexes, err = c.db.GetByPrefix(cachememdb.IndexNameTokenParent, token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, index := range indexes {
|
||||
index.TokenParent = ""
|
||||
err = c.db.Set(index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to persist index", "error", err)
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
case path == vaultPathLeaseRevoke:
|
||||
// TODO: Should lease present in the URL itself be considered here?
|
||||
// Get the lease from the request body
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
leaseIDRaw, ok := jsonBody["lease_id"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get lease_id from request body")
|
||||
}
|
||||
leaseID, ok := leaseIDRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected lease_id the request body to be string")
|
||||
}
|
||||
if err := c.handleCacheClear(ctx, "lease", leaseID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case strings.HasPrefix(path, vaultPathLeaseRevokeForce):
|
||||
// Trim the URL path to get the request path prefix
|
||||
prefix := strings.TrimPrefix(path, vaultPathLeaseRevokeForce)
|
||||
// Get all the cache indexes that use the request path containing the
|
||||
// prefix and cancel the renewer context of each.
|
||||
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, tokenNSID := namespace.SplitIDFromString(req.Token)
|
||||
for _, index := range indexes {
|
||||
_, leaseNSID := namespace.SplitIDFromString(index.Lease)
|
||||
// Only evict leases that match the token's namespace
|
||||
if tokenNSID == leaseNSID {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(path, vaultPathLeaseRevokePrefix):
|
||||
// Trim the URL path to get the request path prefix
|
||||
prefix := strings.TrimPrefix(path, vaultPathLeaseRevokePrefix)
|
||||
// Get all the cache indexes that use the request path containing the
|
||||
// prefix and cancel the renewer context of each.
|
||||
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, tokenNSID := namespace.SplitIDFromString(req.Token)
|
||||
for _, index := range indexes {
|
||||
_, leaseNSID := namespace.SplitIDFromString(index.Lease)
|
||||
// Only evict leases that match the token's namespace
|
||||
if tokenNSID == leaseNSID {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
|
||||
c.logger.Debug("triggered caching eviction from revocation request")
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// deriveNamespaceAndRevocationPath returns the namespace and relative path for
|
||||
// revocation paths.
|
||||
//
|
||||
// If the path contains a namespace, but it's not a revocation path, it will be
|
||||
// returned as-is, since there's no way to tell where the namespace ends and
|
||||
// where the request path begins purely based off a string.
|
||||
//
|
||||
// Case 1: /v1/ns1/leases/revoke -> ns1/, /v1/leases/revoke
|
||||
// Case 2: ns1/ /v1/leases/revoke -> ns1/, /v1/leases/revoke
|
||||
// Case 3: /v1/ns1/foo/bar -> root/, /v1/ns1/foo/bar
|
||||
// Case 4: ns1/ /v1/foo/bar -> ns1/, /v1/foo/bar
|
||||
func deriveNamespaceAndRevocationPath(req *SendRequest) (string, string) {
|
||||
namespace := "root/"
|
||||
nsHeader := req.Request.Header.Get(consts.NamespaceHeaderName)
|
||||
if nsHeader != "" {
|
||||
namespace = nsHeader
|
||||
}
|
||||
|
||||
fullPath := req.Request.URL.Path
|
||||
nonVersionedPath := strings.TrimPrefix(fullPath, "/v1")
|
||||
|
||||
for _, pathToCheck := range revocationPaths {
|
||||
// We use strings.Contains here for paths that can contain
|
||||
// vars in the path, e.g. /v1/lease/revoke-prefix/:prefix
|
||||
i := strings.Index(nonVersionedPath, pathToCheck)
|
||||
// If there's no match, move on to the next check
|
||||
if i == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the index is 0, this is a relative path with no namespace preppended,
|
||||
// so we can break early
|
||||
if i == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// We need to turn /ns1 into ns1/, this makes it easy
|
||||
namespaceInPath := nshelper.Canonicalize(nonVersionedPath[:i])
|
||||
|
||||
// If it's root, we replace, otherwise we join
|
||||
if namespace == "root/" {
|
||||
namespace = namespaceInPath
|
||||
} else {
|
||||
namespace = namespace + namespaceInPath
|
||||
}
|
||||
|
||||
return namespace, fmt.Sprintf("/v1%s", nonVersionedPath[i:])
|
||||
}
|
||||
|
||||
return namespace, fmt.Sprintf("/v1%s", nonVersionedPath)
|
||||
}
|
|
@ -0,0 +1,507 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
)
|
||||
|
||||
func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
|
||||
t.Helper()
|
||||
|
||||
lc, err := NewLeaseCache(&LeaseCacheConfig{
|
||||
BaseContext: context.Background(),
|
||||
Proxier: newMockProxier(responses),
|
||||
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return lc
|
||||
}
|
||||
|
||||
func TestCache_ComputeIndexID(t *testing.T) {
|
||||
type args struct {
|
||||
req *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
req *SendRequest
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"basic",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "test",
|
||||
},
|
||||
},
|
||||
},
|
||||
"2edc7e965c3e1bdce3b1d5f79a52927842569c0734a86544d222753f11ae4847",
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := computeIndexID(tt.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("actual_error: %v, expected_error: %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, string(tt.want)) {
|
||||
t.Errorf("bad: index id; actual: %q, expected: %q", got, string(tt.want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_EmptyToken(t *testing.T) {
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Even if the send request doesn't have a token on it, a successful
|
||||
// cacheable response should result in the index properly getting populated
|
||||
// with a token and memdb shouldn't complain while inserting the index.
|
||||
urlPath := "http://example.com/v1/sample/api"
|
||||
sendReq := &SendRequest{
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("expected a non empty response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_SendCacheable(t *testing.T) {
|
||||
// Emulate 2 responses from the api proxy. One returns a new token and the
|
||||
// other returns a lease.
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`),
|
||||
},
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo", "renewable": true}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "output", "lease_id": "foo", "renewable": true}`),
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Make a request. A response with a new token is returned to the lease
|
||||
// cache and that will be cached.
|
||||
urlPath := "http://example.com/v1/sample/api"
|
||||
sendReq := &SendRequest{
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Send the same request again to get the cached response
|
||||
sendReq = &SendRequest{
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Modify the request a little bit to ensure the second response is
|
||||
// returned to the lease cache. But make sure that the token in the request
|
||||
// is valid.
|
||||
sendReq = &SendRequest{
|
||||
Token: "testtoken",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Make the same request again and ensure that the same reponse is returned
|
||||
// again.
|
||||
sendReq = &SendRequest{
|
||||
Token: "testtoken",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_SendNonCacheable(t *testing.T) {
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output"}`)),
|
||||
},
|
||||
},
|
||||
},
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusNotFound,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid"}`)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Send a request through the lease cache which is not cacheable (there is
|
||||
// no lease information or auth information in the response)
|
||||
sendReq := &SendRequest{
|
||||
Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Since the response is non-cacheable, the second response will be
|
||||
// returned.
|
||||
sendReq = &SendRequest{
|
||||
Token: "foo",
|
||||
Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_SendNonCacheableNonTokenLease(t *testing.T) {
|
||||
// Create the cache
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo"}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "output", "lease_id": "foo"}`),
|
||||
},
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Send a request through lease cache which returns a response containing
|
||||
// lease_id. Response will not be cached because it doesn't belong to a
|
||||
// token that is managed by the lease cache.
|
||||
urlPath := "http://example.com/v1/sample/api"
|
||||
sendReq := &SendRequest{
|
||||
Token: "foo",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Verify that the response is not cached by sending the same request and
|
||||
// by expecting a different response.
|
||||
sendReq = &SendRequest{
|
||||
Token: "foo",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff == nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_HandleCacheClear(t *testing.T) {
|
||||
lc := testNewLeaseCache(t, nil)
|
||||
|
||||
handler := lc.HandleCacheClear(context.Background())
|
||||
ts := httptest.NewServer(handler)
|
||||
defer ts.Close()
|
||||
|
||||
// Test missing body, should return 400
|
||||
resp, err := http.Post(ts.URL, "application/json", nil)
|
||||
if err != nil {
|
||||
t.Fatal()
|
||||
}
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("status code mismatch: expected = %v, got = %v", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
reqType string
|
||||
reqValue string
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
"invalid_type",
|
||||
"foo",
|
||||
"",
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
"invalid_value",
|
||||
"",
|
||||
"bar",
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
"all",
|
||||
"all",
|
||||
"",
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
"foo",
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"by_token",
|
||||
"token",
|
||||
"foo",
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
"foo",
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf("{\"type\": \"%s\", \"value\": \"%s\"}", tc.reqType, tc.reqValue)
|
||||
resp, err := http.Post(ts.URL, "application/json", strings.NewReader(reqBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tc.expectedStatusCode != resp.StatusCode {
|
||||
t.Fatalf("status code mismatch: expected = %v, got = %v", tc.expectedStatusCode, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *SendRequest
|
||||
wantNamespace string
|
||||
wantRelativePath string
|
||||
}{
|
||||
{
|
||||
"non_revocation_full_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns1/sys/mounts",
|
||||
},
|
||||
},
|
||||
},
|
||||
"root/",
|
||||
"/v1/ns1/sys/mounts",
|
||||
},
|
||||
{
|
||||
"non_revocation_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/sys/mounts",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/mounts",
|
||||
},
|
||||
{
|
||||
"non_revocation_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns2/sys/mounts",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/ns2/sys/mounts",
|
||||
},
|
||||
{
|
||||
"revocation_full_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns1/sys/leases/revoke",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke",
|
||||
},
|
||||
{
|
||||
"revocation_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/sys/leases/revoke",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke",
|
||||
},
|
||||
{
|
||||
"revocation_relative_partial_ns",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns2/sys/leases/revoke",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/ns2/",
|
||||
"/v1/sys/leases/revoke",
|
||||
},
|
||||
{
|
||||
"revocation_prefix_full_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
{
|
||||
"revocation_prefix_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
{
|
||||
"revocation_prefix_partial_ns",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns2/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/ns2/",
|
||||
"/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotNamespace, gotRelativePath := deriveNamespaceAndRevocationPath(tt.req)
|
||||
if gotNamespace != tt.wantNamespace {
|
||||
t.Errorf("deriveNamespaceAndRevocationPath() gotNamespace = %v, want %v", gotNamespace, tt.wantNamespace)
|
||||
}
|
||||
if gotRelativePath != tt.wantRelativePath {
|
||||
t.Errorf("deriveNamespaceAndRevocationPath() gotRelativePath = %v, want %v", gotRelativePath, tt.wantRelativePath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/command/agent/config"
|
||||
"github.com/hashicorp/vault/command/server"
|
||||
"github.com/hashicorp/vault/helper/reload"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func ServerListener(lnConfig *config.Listener, logger io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
|
||||
switch lnConfig.Type {
|
||||
case "unix":
|
||||
return unixSocketListener(lnConfig.Config, logger, ui)
|
||||
case "tcp":
|
||||
return tcpListener(lnConfig.Config, logger, ui)
|
||||
default:
|
||||
return nil, nil, nil, fmt.Errorf("unsupported listener type: %q", lnConfig.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func unixSocketListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
|
||||
addr, ok := config["address"].(string)
|
||||
if !ok {
|
||||
return nil, nil, nil, fmt.Errorf("invalid address: %v", config["address"])
|
||||
}
|
||||
|
||||
if addr == "" {
|
||||
return nil, nil, nil, fmt.Errorf("address field should point to socket file path")
|
||||
}
|
||||
|
||||
// Remove the socket file as it shouldn't exist for the domain socket to
|
||||
// work
|
||||
err := os.Remove(addr)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, nil, nil, fmt.Errorf("failed to remove the socket file: %v", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Wrap the listener in rmListener so that the Unix domain socket file is
|
||||
// removed on close.
|
||||
listener = &rmListener{
|
||||
Listener: listener,
|
||||
Path: addr,
|
||||
}
|
||||
|
||||
props := map[string]string{"addr": addr, "tls": "disabled"}
|
||||
|
||||
return listener, props, nil, nil
|
||||
}
|
||||
|
||||
func tcpListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
|
||||
bindProto := "tcp"
|
||||
var addr string
|
||||
addrRaw, ok := config["address"]
|
||||
if !ok {
|
||||
addr = "127.0.0.1:8300"
|
||||
} else {
|
||||
addr = addrRaw.(string)
|
||||
}
|
||||
|
||||
// If they've passed 0.0.0.0, we only want to bind on IPv4
|
||||
// rather than golang's dual stack default
|
||||
if strings.HasPrefix(addr, "0.0.0.0:") {
|
||||
bindProto = "tcp4"
|
||||
}
|
||||
|
||||
ln, err := net.Listen(bindProto, addr)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ln = server.TCPKeepAliveListener{ln.(*net.TCPListener)}
|
||||
|
||||
props := map[string]string{"addr": addr}
|
||||
|
||||
return server.ListenerWrapTLS(ln, props, config, ui)
|
||||
}
|
||||
|
||||
// rmListener is an implementation of net.Listener that forwards most
|
||||
// calls to the listener but also removes a file as part of the close. We
|
||||
// use this to cleanup the unix domain socket on close.
|
||||
type rmListener struct {
|
||||
net.Listener
|
||||
Path string
|
||||
}
|
||||
|
||||
func (l *rmListener) Close() error {
|
||||
// Close the listener itself
|
||||
if err := l.Listener.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the file
|
||||
return os.Remove(l.Path)
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
// SendRequest is the input for Proxier.Send.
|
||||
type SendRequest struct {
|
||||
Token string
|
||||
Request *http.Request
|
||||
RequestBody []byte
|
||||
}
|
||||
|
||||
// SendResponse is the output from Proxier.Send.
|
||||
type SendResponse struct {
|
||||
Response *api.Response
|
||||
ResponseBody []byte
|
||||
}
|
||||
|
||||
// Proxier is the interface implemented by different components that are
|
||||
// responsible for performing specific tasks, such as caching and proxying. All
|
||||
// these tasks combined together would serve the request received by the agent.
|
||||
type Proxier interface {
|
||||
Send(ctx context.Context, req *SendRequest) (*SendResponse, error)
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// mockProxier is a mock implementation of the Proxier interface, used for testing purposes.
|
||||
// The mock will return the provided responses every time it reaches its Send method, up to
|
||||
// the last provided response. This lets tests control what the next/underlying Proxier layer
|
||||
// might expect to return.
|
||||
type mockProxier struct {
|
||||
proxiedResponses []*SendResponse
|
||||
responseIndex int
|
||||
}
|
||||
|
||||
func newMockProxier(responses []*SendResponse) *mockProxier {
|
||||
return &mockProxier{
|
||||
proxiedResponses: responses,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mockProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
if p.responseIndex >= len(p.proxiedResponses) {
|
||||
return nil, fmt.Errorf("index out of bounds: responseIndex = %d, responses = %d", p.responseIndex, len(p.proxiedResponses))
|
||||
}
|
||||
resp := p.proxiedResponses[p.responseIndex]
|
||||
|
||||
p.responseIndex++
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *mockProxier) ResponseIndex() int {
|
||||
return p.responseIndex
|
||||
}
|
|
@ -0,0 +1,280 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
|
||||
"github.com/hashicorp/vault/command/agent/auth"
|
||||
agentapprole "github.com/hashicorp/vault/command/agent/auth/approle"
|
||||
"github.com/hashicorp/vault/command/agent/cache"
|
||||
"github.com/hashicorp/vault/command/agent/sink"
|
||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestCache_UsingAutoAuthToken(t *testing.T) {
|
||||
var err error
|
||||
logger := logging.NewVaultLogger(log.Trace)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: log.NewNullLogger(),
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"approle": credAppRole.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
cores := cluster.Cores
|
||||
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
client := cores[0].Client
|
||||
|
||||
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
||||
os.Setenv(api.EnvVaultAddress, client.Address())
|
||||
|
||||
defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
|
||||
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
|
||||
|
||||
err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
|
||||
Type: "approle",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/approle/role/test1", map[string]interface{}{
|
||||
"bind_secret_id": "true",
|
||||
"token_ttl": "3s",
|
||||
"token_max_ttl": "10s",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := client.Logical().Write("auth/approle/role/test1/secret-id", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secretID1 := resp.Data["secret_id"].(string)
|
||||
|
||||
resp, err = client.Logical().Read("auth/approle/role/test1/role-id")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
roleID1 := resp.Data["role_id"].(string)
|
||||
|
||||
rolef, err := ioutil.TempFile("", "auth.role-id.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
role := rolef.Name()
|
||||
rolef.Close() // WriteFile doesn't need it open
|
||||
defer os.Remove(role)
|
||||
t.Logf("input role_id_file_path: %s", role)
|
||||
|
||||
secretf, err := ioutil.TempFile("", "auth.secret-id.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secret := secretf.Name()
|
||||
secretf.Close()
|
||||
defer os.Remove(secret)
|
||||
t.Logf("input secret_id_file_path: %s", secret)
|
||||
|
||||
// We close these right away because we're just basically testing
|
||||
// permissions and finding a usable file name
|
||||
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out := ouf.Name()
|
||||
ouf.Close()
|
||||
os.Remove(out)
|
||||
t.Logf("output: %s", out)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
timer := time.AfterFunc(30*time.Second, func() {
|
||||
cancelFunc()
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
conf := map[string]interface{}{
|
||||
"role_id_file_path": role,
|
||||
"secret_id_file_path": secret,
|
||||
"remove_secret_id_file_after_reading": true,
|
||||
}
|
||||
|
||||
am, err := agentapprole.NewApproleAuthMethod(&auth.AuthConfig{
|
||||
Logger: logger.Named("auth.approle"),
|
||||
MountPath: "auth/approle",
|
||||
Config: conf,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ahConfig := &auth.AuthHandlerConfig{
|
||||
Logger: logger.Named("auth.handler"),
|
||||
Client: client,
|
||||
}
|
||||
ah := auth.NewAuthHandler(ahConfig)
|
||||
go ah.Run(ctx, am)
|
||||
defer func() {
|
||||
<-ah.DoneCh
|
||||
}()
|
||||
|
||||
config := &sink.SinkConfig{
|
||||
Logger: logger.Named("sink.file"),
|
||||
Config: map[string]interface{}{
|
||||
"path": out,
|
||||
},
|
||||
}
|
||||
fs, err := file.NewFileSink(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config.Sink = fs
|
||||
|
||||
ss := sink.NewSinkServer(&sink.SinkServerConfig{
|
||||
Logger: logger.Named("sink.server"),
|
||||
Client: client,
|
||||
})
|
||||
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
|
||||
defer func() {
|
||||
<-ss.DoneCh
|
||||
}()
|
||||
|
||||
// This has to be after the other defers so it happens first
|
||||
defer cancelFunc()
|
||||
|
||||
// Check that no sink file exists
|
||||
_, err = os.Lstat(out)
|
||||
if err == nil {
|
||||
t.Fatal("expected err")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatal("expected notexist err")
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(role, []byte(roleID1), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test role 1", "path", role)
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(secret, []byte(secretID1), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test secret 1", "path", secret)
|
||||
}
|
||||
|
||||
getToken := func() string {
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
if time.Now().After(timeout) {
|
||||
t.Fatal("did not find a written token after timeout")
|
||||
}
|
||||
val, err := ioutil.ReadFile(out)
|
||||
if err == nil {
|
||||
os.Remove(out)
|
||||
if len(val) == 0 {
|
||||
t.Fatal("written token was empty")
|
||||
}
|
||||
|
||||
_, err = os.Stat(secret)
|
||||
if err == nil {
|
||||
t.Fatal("secret file exists but was supposed to be removed")
|
||||
}
|
||||
|
||||
client.SetToken(string(val))
|
||||
_, err := client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return string(val)
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("auto-auth token: %q", getToken())
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer listener.Close()
|
||||
|
||||
cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")
|
||||
|
||||
// Create the API proxier
|
||||
apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
|
||||
Logger: cacheLogger.Named("apiproxy"),
|
||||
})
|
||||
|
||||
// Create the lease cache proxier and set its underlying proxier to
|
||||
// the API proxier.
|
||||
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
||||
BaseContext: ctx,
|
||||
Proxier: apiProxy,
|
||||
Logger: cacheLogger.Named("leasecache"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
|
||||
|
||||
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, true, client))
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: cacheLogger.StandardLogger(nil),
|
||||
}
|
||||
go server.Serve(listener)
|
||||
|
||||
testClient, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for listeners to come up
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
resp, err = testClient.Logical().Read("auth/token/lookup-self")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("failed to use the auto-auth token to perform lookup-self")
|
||||
}
|
||||
}
|
|
@ -22,6 +22,17 @@ type Config struct {
|
|||
AutoAuth *AutoAuth `hcl:"auto_auth"`
|
||||
ExitAfterAuth bool `hcl:"exit_after_auth"`
|
||||
PidFile string `hcl:"pid_file"`
|
||||
Cache *Cache `hcl:"cache"`
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
UseAutoAuthToken bool `hcl:"use_auto_auth_token"`
|
||||
Listeners []*Listener `hcl:"listeners"`
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
Type string
|
||||
Config map[string]interface{}
|
||||
}
|
||||
|
||||
type AutoAuth struct {
|
||||
|
@ -91,9 +102,102 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
|
|||
return nil, errwrap.Wrapf("error parsing 'auto_auth': {{err}}", err)
|
||||
}
|
||||
|
||||
err = parseCache(&result, list)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func parseCache(result *Config, list *ast.ObjectList) error {
|
||||
name := "cache"
|
||||
|
||||
cacheList := list.Filter(name)
|
||||
if len(cacheList.Items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(cacheList.Items) > 1 {
|
||||
return fmt.Errorf("one and only one %q block is required", name)
|
||||
}
|
||||
|
||||
item := cacheList.Items[0]
|
||||
|
||||
var c Cache
|
||||
err := hcl.DecodeObject(&c, item.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result.Cache = &c
|
||||
|
||||
subs, ok := item.Val.(*ast.ObjectType)
|
||||
if !ok {
|
||||
return fmt.Errorf("could not parse %q as an object", name)
|
||||
}
|
||||
subList := subs.List
|
||||
|
||||
err = parseListeners(result, subList)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error parsing 'listener' stanzas: {{err}}", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseListeners(result *Config, list *ast.ObjectList) error {
|
||||
name := "listener"
|
||||
|
||||
listenerList := list.Filter(name)
|
||||
if len(listenerList.Items) < 1 {
|
||||
return fmt.Errorf("at least one %q block is required", name)
|
||||
}
|
||||
|
||||
var listeners []*Listener
|
||||
for _, item := range listenerList.Items {
|
||||
var lnConfig map[string]interface{}
|
||||
err := hcl.DecodeObject(&lnConfig, item.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var lnType string
|
||||
switch {
|
||||
case lnConfig["type"] != nil:
|
||||
lnType = lnConfig["type"].(string)
|
||||
delete(lnConfig, "type")
|
||||
case len(item.Keys) == 1:
|
||||
lnType = strings.ToLower(item.Keys[0].Token.Value().(string))
|
||||
default:
|
||||
return errors.New("listener type must be specified")
|
||||
}
|
||||
|
||||
switch lnType {
|
||||
case "unix":
|
||||
// Don't accept TLS connection information for unix domain socket
|
||||
// listener. Maybe something to support in future.
|
||||
unixLnConfig := map[string]interface{}{
|
||||
"tls_disable": true,
|
||||
}
|
||||
unixLnConfig["address"] = lnConfig["address"]
|
||||
lnConfig = unixLnConfig
|
||||
case "tcp":
|
||||
default:
|
||||
return fmt.Errorf("invalid listener type %q", lnType)
|
||||
}
|
||||
|
||||
listeners = append(listeners, &Listener{
|
||||
Type: lnType,
|
||||
Config: lnConfig,
|
||||
})
|
||||
}
|
||||
|
||||
result.Cache.Listeners = listeners
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAutoAuth(result *Config, list *ast.ObjectList) error {
|
||||
name := "auto_auth"
|
||||
|
||||
|
|
|
@ -10,6 +10,80 @@ import (
|
|||
"github.com/hashicorp/vault/helper/logging"
|
||||
)
|
||||
|
||||
func TestLoadConfigFile_AgentCache(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
config, err := LoadConfig("./test-fixtures/config-cache.hcl", logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
AutoAuth: &AutoAuth{
|
||||
Method: &Method{
|
||||
Type: "aws",
|
||||
WrapTTL: 300 * time.Second,
|
||||
MountPath: "auth/aws",
|
||||
Config: map[string]interface{}{
|
||||
"role": "foobar",
|
||||
},
|
||||
},
|
||||
Sinks: []*Sink{
|
||||
&Sink{
|
||||
Type: "file",
|
||||
DHType: "curve25519",
|
||||
DHPath: "/tmp/file-foo-dhpath",
|
||||
AAD: "foobar",
|
||||
Config: map[string]interface{}{
|
||||
"path": "/tmp/file-foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Cache: &Cache{
|
||||
UseAutoAuthToken: true,
|
||||
Listeners: []*Listener{
|
||||
&Listener{
|
||||
Type: "unix",
|
||||
Config: map[string]interface{}{
|
||||
"address": "/path/to/socket",
|
||||
"tls_disable": true,
|
||||
},
|
||||
},
|
||||
&Listener{
|
||||
Type: "tcp",
|
||||
Config: map[string]interface{}{
|
||||
"address": "127.0.0.1:8300",
|
||||
"tls_disable": true,
|
||||
},
|
||||
},
|
||||
&Listener{
|
||||
Type: "tcp",
|
||||
Config: map[string]interface{}{
|
||||
"address": "127.0.0.1:8400",
|
||||
"tls_key_file": "/path/to/cakey.pem",
|
||||
"tls_cert_file": "/path/to/cacert.pem",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
PidFile: "./pidfile",
|
||||
}
|
||||
|
||||
if diff := deep.Equal(config, expected); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
config, err = LoadConfig("./test-fixtures/config-cache-embedded-type.hcl", logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(config, expected); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFile(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
pid_file = "./pidfile"
|
||||
|
||||
auto_auth {
|
||||
method {
|
||||
type = "aws"
|
||||
wrap_ttl = 300
|
||||
config = {
|
||||
role = "foobar"
|
||||
}
|
||||
}
|
||||
|
||||
sink {
|
||||
type = "file"
|
||||
config = {
|
||||
path = "/tmp/file-foo"
|
||||
}
|
||||
aad = "foobar"
|
||||
dh_type = "curve25519"
|
||||
dh_path = "/tmp/file-foo-dhpath"
|
||||
}
|
||||
}
|
||||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
|
||||
listener {
|
||||
type = "unix"
|
||||
address = "/path/to/socket"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener {
|
||||
type = "tcp"
|
||||
address = "127.0.0.1:8300"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener {
|
||||
type = "tcp"
|
||||
address = "127.0.0.1:8400"
|
||||
tls_key_file = "/path/to/cakey.pem"
|
||||
tls_cert_file = "/path/to/cacert.pem"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
pid_file = "./pidfile"
|
||||
|
||||
auto_auth {
|
||||
method {
|
||||
type = "aws"
|
||||
wrap_ttl = 300
|
||||
config = {
|
||||
role = "foobar"
|
||||
}
|
||||
}
|
||||
|
||||
sink {
|
||||
type = "file"
|
||||
config = {
|
||||
path = "/tmp/file-foo"
|
||||
}
|
||||
aad = "foobar"
|
||||
dh_type = "curve25519"
|
||||
dh_path = "/tmp/file-foo-dhpath"
|
||||
}
|
||||
}
|
||||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
|
||||
listener "unix" {
|
||||
address = "/path/to/socket"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8300"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8400"
|
||||
tls_key_file = "/path/to/cakey.pem"
|
||||
tls_cert_file = "/path/to/cacert.pem"
|
||||
}
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
|
||||
|
@ -30,6 +31,188 @@ func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCo
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgent_Cache_UnixListener(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(hclog.Trace)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
Logger: logger.Named("core"),
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"jwt": vaultjwt.Factory,
|
||||
},
|
||||
}
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
||||
os.Setenv(api.EnvVaultAddress, client.Address())
|
||||
|
||||
defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
|
||||
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
|
||||
|
||||
// Setup Vault
|
||||
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
|
||||
Type: "jwt",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
|
||||
"bound_issuer": "https://team-vault.auth0.com/",
|
||||
"jwt_validation_pubkeys": agent.TestECDSAPubKey,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
|
||||
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
"bound_audiences": "https://vault.plugin.auth.jwt.test",
|
||||
"user_claim": "https://vault/user",
|
||||
"groups_claim": "https://vault/groups",
|
||||
"policies": "test",
|
||||
"period": "3s",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
inf, err := ioutil.TempFile("", "auth.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
in := inf.Name()
|
||||
inf.Close()
|
||||
os.Remove(in)
|
||||
t.Logf("input: %s", in)
|
||||
|
||||
sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sink1 := sink1f.Name()
|
||||
sink1f.Close()
|
||||
os.Remove(sink1)
|
||||
t.Logf("sink1: %s", sink1)
|
||||
|
||||
sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sink2 := sink2f.Name()
|
||||
sink2f.Close()
|
||||
os.Remove(sink2)
|
||||
t.Logf("sink2: %s", sink2)
|
||||
|
||||
conff, err := ioutil.TempFile("", "conf.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conf := conff.Name()
|
||||
conff.Close()
|
||||
os.Remove(conf)
|
||||
t.Logf("config: %s", conf)
|
||||
|
||||
jwtToken, _ := agent.GetTestJWT(t)
|
||||
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test jwt", "path", in)
|
||||
}
|
||||
|
||||
socketff, err := ioutil.TempFile("", "cache.socket.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
socketf := socketff.Name()
|
||||
socketff.Close()
|
||||
os.Remove(socketf)
|
||||
t.Logf("socketf: %s", socketf)
|
||||
|
||||
config := `
|
||||
auto_auth {
|
||||
method {
|
||||
type = "jwt"
|
||||
config = {
|
||||
role = "test"
|
||||
path = "%s"
|
||||
}
|
||||
}
|
||||
|
||||
sink {
|
||||
type = "file"
|
||||
config = {
|
||||
path = "%s"
|
||||
}
|
||||
}
|
||||
|
||||
sink "file" {
|
||||
config = {
|
||||
path = "%s"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
|
||||
listener "unix" {
|
||||
address = "%s"
|
||||
tls_disable = true
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
config = fmt.Sprintf(config, in, sink1, sink2, socketf)
|
||||
if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test config", "path", conf)
|
||||
}
|
||||
|
||||
_, cmd := testAgentCommand(t, logger)
|
||||
cmd.client = client
|
||||
|
||||
// Kill the command 5 seconds after it starts
|
||||
go func() {
|
||||
select {
|
||||
case <-cmd.ShutdownCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
cmd.ShutdownCh <- struct{}{}
|
||||
}
|
||||
}()
|
||||
|
||||
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress)
|
||||
|
||||
// Create a client that talks to the agent
|
||||
os.Setenv(api.EnvVaultAgentAddress, socketf)
|
||||
testClient, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress)
|
||||
|
||||
// Start the agent
|
||||
go cmd.Run([]string{"-config", conf})
|
||||
|
||||
// Give some time for the auto-auth to complete
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Invoke lookup self through the agent
|
||||
secret, err := testClient.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil || secret.Data == nil || secret.Data["id"].(string) == "" {
|
||||
t.Fatalf("failed to perform lookup self through agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitAfterAuth(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(hclog.Trace)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
|
|
|
@ -39,6 +39,7 @@ type BaseCommand struct {
|
|||
flagsOnce sync.Once
|
||||
|
||||
flagAddress string
|
||||
flagAgentAddress string
|
||||
flagCACert string
|
||||
flagCAPath string
|
||||
flagClientCert string
|
||||
|
@ -78,6 +79,9 @@ func (c *BaseCommand) Client() (*api.Client, error) {
|
|||
if c.flagAddress != "" {
|
||||
config.Address = c.flagAddress
|
||||
}
|
||||
if c.flagAgentAddress != "" {
|
||||
config.Address = c.flagAgentAddress
|
||||
}
|
||||
|
||||
if c.flagOutputCurlString {
|
||||
config.OutputCurlString = c.flagOutputCurlString
|
||||
|
@ -220,6 +224,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
|
|||
}
|
||||
f.StringVar(addrStringVar)
|
||||
|
||||
agentAddrStringVar := &StringVar{
|
||||
Name: "agent-address",
|
||||
Target: &c.flagAgentAddress,
|
||||
EnvVar: "VAULT_AGENT_ADDR",
|
||||
Completion: complete.PredictAnything,
|
||||
Usage: "Address of the Agent.",
|
||||
}
|
||||
f.StringVar(agentAddrStringVar)
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "ca-cert",
|
||||
Target: &c.flagCACert,
|
||||
|
|
|
@ -72,7 +72,7 @@ func listenerWrapProxy(ln net.Listener, config map[string]interface{}) (net.List
|
|||
return newLn, nil
|
||||
}
|
||||
|
||||
func listenerWrapTLS(
|
||||
func ListenerWrapTLS(
|
||||
ln net.Listener,
|
||||
props map[string]string,
|
||||
config map[string]interface{},
|
||||
|
|
|
@ -35,7 +35,7 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
|
|||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ln = tcpKeepAliveListener{ln.(*net.TCPListener)}
|
||||
ln = TCPKeepAliveListener{ln.(*net.TCPListener)}
|
||||
|
||||
ln, err = listenerWrapProxy(ln, config)
|
||||
if err != nil {
|
||||
|
@ -94,20 +94,20 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
|
|||
config["x_forwarded_for_reject_not_authorized"] = true
|
||||
}
|
||||
|
||||
return listenerWrapTLS(ln, props, config, ui)
|
||||
return ListenerWrapTLS(ln, props, config, ui)
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
// go away.
|
||||
//
|
||||
// This is copied directly from the Go source code.
|
||||
type tcpKeepAliveListener struct {
|
||||
type TCPKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue