open-vault/command/agent/cache/lease_cache_test.go
2019-07-18 21:10:15 -04:00

512 lines
13 KiB
Go

package cache
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
)
// testNewLeaseCache builds a LeaseCache for use in tests. The cache is wired
// to a default-configured API client, a background base context, a mock
// proxier constructed from the supplied responses, and a trace-level logger
// named "cache.leasecache". Any construction error fails the calling test
// immediately.
func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
	t.Helper()

	apiClient, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}

	conf := &LeaseCacheConfig{
		Client:      apiClient,
		BaseContext: context.Background(),
		Proxier:     newMockProxier(responses),
		Logger:      logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
	}

	leaseCache, err := NewLeaseCache(conf)
	if err != nil {
		t.Fatal(err)
	}
	return leaseCache
}
// TestCache_ComputeIndexID runs computeIndexID against a table of requests
// and checks both the returned index ID and whether an error was expected.
func TestCache_ComputeIndexID(t *testing.T) {
	tests := []struct {
		name    string
		req     *SendRequest
		want    string
		wantErr bool
	}{
		{
			"basic",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "test",
					},
				},
			},
			"7b5db388f211fd9edca8c6c254831fb01ad4e6fe624dbb62711f256b5e803717",
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := computeIndexID(tt.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("actual_error: %v, expected_error: %v", err, tt.wantErr)
				// The returned ID is not meaningful when the error
				// expectation is already violated.
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("bad: index id; actual: %q, expected: %q", got, tt.want)
			}
		})
	}
}
func TestLeaseCache_EmptyToken(t *testing.T) {
responses := []*SendResponse{
newTestSendResponse(http.StatusCreated, `{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
}
lc := testNewLeaseCache(t, responses)
// Even if the send request doesn't have a token on it, a successful
// cacheable response should result in the index properly getting populated
// with a token and memdb shouldn't complain while inserting the index.
urlPath := "http://example.com/v1/sample/api"
sendReq := &SendRequest{
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
}
resp, err := lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatalf("expected a non empty response")
}
}
// TestLeaseCache_SendCacheable sends four requests with a registered
// auto-auth token and compares the status code of each result against the
// emulated proxy response it is expected to correspond to.
func TestLeaseCache_SendCacheable(t *testing.T) {
	// Emulate 2 responses from the api proxy: the first returns a new token,
	// the second returns a lease.
	responses := []*SendResponse{
		newTestSendResponse(http.StatusCreated, `{"auth": {"client_token": "testtoken", "renewable": true}}`),
		newTestSendResponse(http.StatusOK, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}}`),
	}
	lc := testNewLeaseCache(t, responses)

	// Register a token so that the token and lease requests are cached.
	lc.RegisterAutoAuthToken("autoauthtoken")

	const urlPath = "http://example.com/v1/sample/api"

	steps := []struct {
		body string
		want *SendResponse
	}{
		// First request: the response with the new token comes back and is
		// cached.
		{`{"value": "input"}`, responses[0]},
		// Same request again: the cached response is expected.
		{`{"value": "input"}`, responses[0]},
		// Slightly modified request: the second response is expected.
		{`{"value": "input_changed"}`, responses[1]},
		// Same modified request again: the same response is expected again.
		{`{"value": "input_changed"}`, responses[1]},
	}

	for _, step := range steps {
		sendReq := &SendRequest{
			Token:   "autoauthtoken",
			Request: httptest.NewRequest("GET", urlPath, strings.NewReader(step.body)),
		}
		resp, err := lc.Send(context.Background(), sendReq)
		if err != nil {
			t.Fatal(err)
		}
		if diff := deep.Equal(resp.Response.StatusCode, step.want.Response.StatusCode); diff != nil {
			t.Fatalf("expected getting proxied response: got %v", diff)
		}
	}
}
// TestLeaseCache_SendNonCacheable sends four requests through the lease cache
// and expects the i-th request to yield the i-th emulated response, i.e. no
// response is served from the cache.
func TestLeaseCache_SendNonCacheable(t *testing.T) {
	responses := []*SendResponse{
		newTestSendResponse(http.StatusOK, `{"value": "output"}`),
		newTestSendResponse(http.StatusNotFound, `{"value": "invalid"}`),
		newTestSendResponse(http.StatusOK, `<html>Hello</html>`),
		newTestSendResponse(http.StatusTemporaryRedirect, ""),
	}
	lc := testNewLeaseCache(t, responses)

	const target = "http://example.com"

	steps := []struct {
		token    string
		withBody bool
	}{
		// Not cacheable: there is no lease information or auth information
		// in the response.
		{token: "", withBody: true},
		// Since the previous response is non-cacheable, the second response
		// will be returned.
		{token: "foo", withBody: true},
		// Likewise the third response will be returned.
		{token: "foo", withBody: false},
		// Likewise the fourth response will be returned.
		{token: "foo", withBody: false},
	}

	for i, step := range steps {
		var httpReq *http.Request
		if step.withBody {
			httpReq = httptest.NewRequest("GET", target, strings.NewReader(`{"value": "input"}`))
		} else {
			httpReq = httptest.NewRequest("GET", target, nil)
		}

		resp, err := lc.Send(context.Background(), &SendRequest{
			Token:   step.token,
			Request: httpReq,
		})
		if err != nil {
			t.Fatal(err)
		}
		if diff := deep.Equal(resp.Response, responses[i].Response); diff != nil {
			t.Fatalf("expected getting proxied response: got %v", diff)
		}
	}
}
// TestLeaseCache_SendNonCacheableNonTokenLease sends the same request twice
// with a token that was not registered with the lease cache. Each send must
// return the next emulated response, and the request-path index lookup must
// come back empty after every send.
func TestLeaseCache_SendNonCacheableNonTokenLease(t *testing.T) {
	// Create the cache
	responses := []*SendResponse{
		newTestSendResponse(http.StatusOK, `{"value": "output", "lease_id": "foo"}`),
		newTestSendResponse(http.StatusCreated, `{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
	}
	lc := testNewLeaseCache(t, responses)

	const urlPath = "http://example.com/v1/sample/api"

	// The first response contains a lease_id, but it will not be cached
	// because it doesn't belong to a token that is managed by the lease
	// cache. Sending the identical request a second time therefore has to
	// produce the second response rather than a cached copy of the first.
	for _, want := range responses {
		sendReq := &SendRequest{
			Token:   "foo",
			Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
		}
		resp, err := lc.Send(context.Background(), sendReq)
		if err != nil {
			t.Fatal(err)
		}
		if diff := deep.Equal(resp.Response, want.Response); diff != nil {
			t.Fatalf("expected getting proxied response: got %v", diff)
		}

		idx, err := lc.db.Get(cachememdb.IndexNameRequestPath, "root/", urlPath)
		if err != nil {
			t.Fatal(err)
		}
		if idx != nil {
			t.Fatalf("expected nil entry, got: %#v", idx)
		}
	}
}
// TestLeaseCache_HandleCacheClear serves the cache-clear handler from a test
// HTTP server and checks the status code returned for a missing body and for
// a table of type/value combinations.
func TestLeaseCache_HandleCacheClear(t *testing.T) {
	lc := testNewLeaseCache(t, nil)
	handler := lc.HandleCacheClear(context.Background())
	ts := httptest.NewServer(handler)
	defer ts.Close()

	// Test missing body, should return 400
	resp, err := http.Post(ts.URL, "application/json", nil)
	if err != nil {
		// Report the underlying error instead of failing silently.
		t.Fatal(err)
	}
	// Deferred calls run last-in-first-out, so this body is closed before
	// the test server is shut down.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadRequest {
		t.Fatalf("status code mismatch: expected = %v, got = %v", http.StatusBadRequest, resp.StatusCode)
	}

	testCases := []struct {
		name               string
		reqType            string
		reqValue           string
		expectedStatusCode int
	}{
		{
			"invalid_type",
			"foo",
			"",
			http.StatusBadRequest,
		},
		{
			"invalid_value",
			"",
			"bar",
			http.StatusBadRequest,
		},
		{
			"all",
			"all",
			"",
			http.StatusOK,
		},
		{
			"by_request_path",
			"request_path",
			"foo",
			http.StatusOK,
		},
		{
			"by_token",
			"token",
			"foo",
			http.StatusOK,
		},
		{
			"by_lease",
			"lease",
			"foo",
			http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			reqBody := fmt.Sprintf("{\"type\": \"%s\", \"value\": \"%s\"}", tc.reqType, tc.reqValue)
			resp, err := http.Post(ts.URL, "application/json", strings.NewReader(reqBody))
			if err != nil {
				t.Fatal(err)
			}
			// Close each subtest's response body when the subtest returns.
			defer resp.Body.Close()
			if tc.expectedStatusCode != resp.StatusCode {
				t.Fatalf("status code mismatch: expected = %v, got = %v", tc.expectedStatusCode, resp.StatusCode)
			}
		})
	}
}
// TestCache_DeriveNamespaceAndRevocationPath feeds requests with different
// URL paths and namespace headers to deriveNamespaceAndRevocationPath and
// compares the returned namespace and relative path with the expected values.
//
// Every case has a distinct name so that each subtest can be identified and
// selected on its own.
func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
	tests := []struct {
		name             string
		req              *SendRequest
		wantNamespace    string
		wantRelativePath string
	}{
		// Non-revocation paths: the path is expected back unchanged.
		{
			"non_revocation_full_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns1/sys/mounts",
					},
				},
			},
			"root/",
			"/v1/ns1/sys/mounts",
		},
		{
			"non_revocation_relative_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/sys/mounts",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/sys/mounts",
		},
		{
			"non_revocation_relative_partial_ns",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns2/sys/mounts",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/ns2/sys/mounts",
		},
		// Revocation paths: the namespace portion of the path is expected to
		// move into the returned namespace.
		{
			"revocation_full_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns1/sys/leases/revoke",
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke",
		},
		{
			"revocation_relative_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/sys/leases/revoke",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke",
		},
		{
			"revocation_relative_partial_ns",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns2/sys/leases/revoke",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/ns2/",
			"/v1/sys/leases/revoke",
		},
		// Revoke-prefix paths: same expectations as above, with the trailing
		// prefix segment kept in the relative path.
		{
			"revocation_prefix_full_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns1/sys/leases/revoke-prefix/foo",
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke-prefix/foo",
		},
		{
			"revocation_prefix_relative_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/sys/leases/revoke-prefix/foo",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke-prefix/foo",
		},
		{
			"revocation_prefix_partial_ns",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns2/sys/leases/revoke-prefix/foo",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/ns2/",
			"/v1/sys/leases/revoke-prefix/foo",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotNamespace, gotRelativePath := deriveNamespaceAndRevocationPath(tt.req)
			if gotNamespace != tt.wantNamespace {
				t.Errorf("deriveNamespaceAndRevocationPath() gotNamespace = %v, want %v", gotNamespace, tt.wantNamespace)
			}
			if gotRelativePath != tt.wantRelativePath {
				t.Errorf("deriveNamespaceAndRevocationPath() gotRelativePath = %v, want %v", gotRelativePath, tt.wantRelativePath)
			}
		})
	}
}