open-vault/sdk/testing/stepwise/stepwise_test.go
Brian Kassouf 303c2aee7c
Run a more strict formatter over the code (#11312)
* Update tooling

* Run gofumpt

* go mod vendor
2021-04-08 09:43:39 -07:00

449 lines
10 KiB
Go

package stepwise
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"os"
"reflect"
"sync"
"testing"
"time"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
)
// testTesting is a flag used when testing the legacy testing framework. It
// starts out false.
var testTesting bool
// testRun describes one table-driven case for TestStepwise_Run_Basic.
type testRun struct {
	// expectedTestT is the mockT state the run should produce. When nil, the
	// test compares against a zero-value mockT.
	expectedTestT *mockT
	// environment is the mock environment handed to Run. Leaving it nil
	// exercises the missing-environment path.
	environment *mockEnvironment
	// steps are the Steps placed in the Case, in order.
	steps []Step
	// skipTeardown is copied to Case.SkipTeardown.
	skipTeardown bool
	// requests, when non-nil, is the request tally the environment is expected
	// to have recorded once Run returns.
	requests *requestCounts
}
// TestStepwise_Run_SkipIfNotAcc verifies that the Stepwise Run function skips
// a case when the VAULT_ACC environment variable is not set. It is kept
// separate from the table tests because it has to blank the variable and then
// put it back; every other test assumes the variable is set.
func TestStepwise_Run_SkipIfNotAcc(t *testing.T) {
	// Blank the variable for the duration of this test only.
	if err := os.Setenv(TestEnvVar, ""); err != nil {
		t.Fatalf("err: %s", err)
	}
	defer os.Setenv(TestEnvVar, "1")

	const wantSkip = true

	recorder := new(mockT)
	Run(recorder, Case{
		Environment: new(mockEnvironment),
		Steps:       []Step{{}},
	})

	if recorder.SkipCalled != wantSkip {
		t.Fatalf("expected SkipCalled (%t), got (%t)", wantSkip, recorder.SkipCalled)
	}
}
// TestStepwise_Run_Basic drives Run through a table of cases against the mock
// environment. For each case it checks whether Teardown was called as
// requested, whether the mock TestT recorded an Error, and, when the case
// lists expected request counts, that the environment served exactly those
// requests.
func TestStepwise_Run_Basic(t *testing.T) {
	testRuns := map[string]testRun{
		"basic_list": {
			steps: []Step{
				stepFunc("keys", ListOperation, false),
			},
			environment: new(mockEnvironment),
			requests: &requestCounts{
				listRequests: 1,
			},
		},
		"basic_list_read": {
			steps: []Step{
				stepFunc("keys", ListOperation, false),
				stepFunc("keys/name", ReadOperation, false),
			},
			environment: new(mockEnvironment),
			requests: &requestCounts{
				listRequests:   1,
				readRequests:   1,
				revokeRequests: 1,
			},
		},
		"basic_unauth": {
			steps: []Step{
				stepFuncWithoutAuth("keys", ListOperation, true),
			},
			expectedTestT: &mockT{
				ErrorCalled: true,
			},
			environment: new(mockEnvironment),
		},
		"error": {
			steps: []Step{
				stepFunc("keys", ListOperation, false),
				stepFunc("keys/something", ReadOperation, true),
			},
			expectedTestT: &mockT{
				ErrorCalled: true,
			},
			environment: new(mockEnvironment),
			requests: &requestCounts{
				listRequests: 1,
			},
		},
		"nil-env": {
			expectedTestT: &mockT{
				FatalCalled: true,
			},
			steps: []Step{
				stepFunc("keys", ListOperation, false),
			},
		},
		"skipTeardown": {
			steps: []Step{
				stepFunc("keys", ListOperation, false),
				stepFunc("keys/name", ReadOperation, false),
				stepFunc("keys/name", ReadOperation, false),
				stepFunc("keys/name", DeleteOperation, false),
			},
			skipTeardown: true,
			environment:  new(mockEnvironment),
			requests: &requestCounts{
				listRequests:   1,
				readRequests:   2,
				revokeRequests: 2,
				deleteRequests: 1,
			},
		},
	}

	for name, tr := range testRuns {
		t.Run(name, func(t *testing.T) {
			testT := new(mockT)
			expectedT := tr.expectedTestT
			if expectedT == nil {
				expectedT = new(mockT)
			}

			testCase := Case{
				Steps:        tr.steps,
				SkipTeardown: tr.skipTeardown,
			}
			// Only assign a non-nil environment. Presumably Case.Environment is
			// an interface, in which case storing a nil *mockEnvironment would
			// make the field compare as non-nil.
			if tr.environment != nil {
				testCase.Environment = tr.environment
			}

			Run(testT, testCase)

			if tr.environment == nil && !testT.FatalCalled {
				t.Fatal("expected FatalCalled with nil environment, but wasn't")
			}

			if tr.environment != nil {
				if tr.skipTeardown && tr.environment.teardownCalled {
					t.Fatal("SkipTeardown is true, but Teardown was called")
				}
				if !tr.skipTeardown && !tr.environment.teardownCalled {
					t.Fatal("SkipTeardown is false, but Teardown was not called")
				}
			}

			if expectedT.ErrorCalled != testT.ErrorCalled {
				t.Fatalf("expected ErrorCalled (%t), got (%t)", expectedT.ErrorCalled, testT.ErrorCalled)
			}

			if tr.requests != nil {
				// The counters live on the environment; without one there is
				// nothing to compare against, so fail with a clear message
				// instead of dereferencing a nil pointer.
				if tr.environment == nil {
					t.Fatal("expected request counts, but the test case has no environment to record them")
				}
				// Dereference the expectation so both sides print as values.
				if !reflect.DeepEqual(*tr.requests, tr.environment.requests) {
					t.Fatalf("request counts do not match: %#v / %#v", *tr.requests, tr.environment.requests)
				}
			}
		})
	}
}
// requestCounts tallies the requests served by the mock environment's test
// server, with one counter per kind of request.
type requestCounts struct {
	// Requests served by the /v1/test/keys/name handler.
	writeRequests, readRequests, deleteRequests int
	// Requests served by the /v1/sys/leases/revoke handler.
	revokeRequests int
	// Requests served by the /v1/test/keys handler.
	listRequests int
}
// TestStepwise_makeRequest checks that makeRequest turns each Step Operation
// into the matching request against the mock environment. The mock server
// echoes the kind of request it served in the secret's RequestID, which is
// what each sub-test compares against.
func TestStepwise_makeRequest(t *testing.T) {
	me := new(mockEnvironment)
	if err := me.Setup(); err != nil {
		t.Fatalf("err: %s", err)
	}
	// Close the test HTTP server once every sub-test has run so the listener
	// is not leaked into the rest of the test binary.
	defer me.Teardown()

	testT := new(mockT)

	type testRequest struct {
		Operation         Operation
		Path              string
		ExpectedRequestID string
		ExpectErr         bool
		UnAuth            bool
	}

	testRequests := map[string]testRequest{
		"list": {
			Operation:         ListOperation,
			Path:              "keys",
			ExpectedRequestID: "list-request",
		},
		"read": {
			Operation:         ReadOperation,
			Path:              "keys/name",
			ExpectedRequestID: "read-request",
		},
		"write": {
			Operation:         WriteOperation,
			Path:              "keys/name",
			ExpectedRequestID: "write-request",
		},
		"update": {
			Operation:         UpdateOperation,
			Path:              "keys/name",
			ExpectedRequestID: "write-request",
		},
		"update_unauth": {
			Operation: UpdateOperation,
			Path:      "keys/name",
			UnAuth:    true,
			ExpectErr: true,
		},
		"delete": {
			Operation:         DeleteOperation,
			Path:              "keys/name",
			ExpectedRequestID: "delete-request",
		},
		"error": {
			Operation: ReadOperation,
			Path:      "error",
			ExpectErr: true,
		},
	}

	for name, tc := range testRequests {
		t.Run(name, func(t *testing.T) {
			step := Step{
				Operation:       tc.Operation,
				Path:            tc.Path,
				Unauthenticated: tc.UnAuth,
			}

			secret, err := makeRequest(testT, me, step)
			if err != nil && !tc.ExpectErr {
				t.Fatalf("unexpected error: %s", err)
			}
			if err == nil && tc.ExpectErr {
				t.Fatal("expected error but got none:")
			}
			if err != nil && tc.ExpectErr {
				// The expected failure happened; there is no secret to inspect.
				return
			}

			// Report a missing secret as a test failure rather than panicking
			// on the field access below.
			if secret == nil {
				t.Fatalf("expected a secret with request ID (%s), got nil", tc.ExpectedRequestID)
			}
			if secret.RequestID != tc.ExpectedRequestID {
				t.Fatalf("expected (%s), got (%s)", tc.ExpectedRequestID, secret.RequestID)
			}
		})
	}
}
// mockEnvironment is a test environment backed by a local httptest server
// instead of a real Vault cluster.
type mockEnvironment struct {
	// ts is the test HTTP server started by Setup and closed by Teardown.
	ts *httptest.Server
	// client is the API client built on the first call to Client and reused
	// afterwards.
	client *api.Client
	// l is held by Client while it checks for and creates the cached client.
	l sync.Mutex

	// teardownCalled is set to true by Teardown.
	teardownCalled bool
	// requests counts the requests the test server's handlers have served.
	requests requestCounts
}
// Setup creates the mock environment, establishing a test HTTP server that
// answers the list, read/write/delete and lease-revoke endpoints used by the
// tests and counts the requests it serves in m.requests. It always returns
// nil.
func (m *mockEnvironment) Setup() error {
	// authorized runs the shared auth check and reports whether the request
	// carried a token. checkAuth only writes the 403 status, so each handler
	// has to stop on its own; otherwise an unauthenticated request would still
	// be counted and answered with a secret body.
	authorized := func(w http.ResponseWriter, req *http.Request) bool {
		checkAuth(w, req)
		return req.Header.Get("X-Vault-Token") != ""
	}

	mux := http.NewServeMux()

	// LIST
	mux.HandleFunc("/v1/test/keys", func(w http.ResponseWriter, req *http.Request) {
		if !authorized(w, req) {
			return
		}
		switch req.Method {
		case "GET":
			m.requests.listRequests++
			respondCommon("list", true, w, req)
		default:
			w.WriteHeader(http.StatusBadRequest)
		}
	})

	// lease revoke
	mux.HandleFunc("/v1/sys/leases/revoke", func(w http.ResponseWriter, req *http.Request) {
		if !authorized(w, req) {
			return
		}
		m.requests.revokeRequests++
		w.WriteHeader(http.StatusOK)
	})

	// READ, DELETE, WRITE
	mux.HandleFunc("/v1/test/keys/name", func(w http.ResponseWriter, req *http.Request) {
		if !authorized(w, req) {
			return
		}

		var method string
		// indicates if the common response should leave out the lease id
		var excludeLease bool
		switch req.Method {
		case "GET":
			m.requests.readRequests++
			method = "read"
		// Go cases do not fall through, so POST and PUT must share one case
		// for both to be treated as writes.
		case "POST", "PUT":
			m.requests.writeRequests++
			method = "write"
		case "DELETE":
			m.requests.deleteRequests++
			excludeLease = true
			method = "delete"
		default:
			// Unsupported method: reject it without a secret body.
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		respondCommon(method, excludeLease, w, req)
	})

	// fall through that rejects any other url than "/"
	mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/" {
			http.NotFound(w, req)
			return
		}
		fmt.Fprintf(w, "{}")
	})

	m.ts = httptest.NewServer(mux)
	return nil
}
// respondCommon returns a mock secret with a request ID that matches the
// request method that was used to invoke it. A true Vault server would not
// respond with a request id / lease id for DELETE or REVOKE, but we do that
// here to verify that the makeRequest method translates the Step Operation
// and calls delete/revoke correctly
func respondCommon(id string, excludeLease bool, w http.ResponseWriter, req *http.Request) {
resp := api.Secret{
RequestID: id + "-request",
LeaseID: "lease-id",
}
if excludeLease {
resp.LeaseID = ""
}
out, err := jsonutil.EncodeJSON(resp)
if err != nil {
panic(err)
}
w.Write(out)
}
// Client returns a Vault API client pointed at the mock environment's test
// server. The client is created on the first call, given the environment's
// root token, and cached; later calls return the same client. An error is
// returned if the test server has not been started.
func (m *mockEnvironment) Client() (*api.Client, error) {
	m.l.Lock()
	defer m.l.Unlock()

	switch {
	case m.ts == nil:
		return nil, errors.New("client() called but test server is nil")
	case m.client != nil:
		return m.client, nil
	}

	created, err := api.NewClient(&api.Config{
		Address:    m.ts.URL,
		HttpClient: cleanhttp.DefaultPooledClient(),
		Timeout:    5 * time.Second,
		MaxRetries: 2,
	})
	if err != nil {
		return nil, err
	}

	// need to set root token here to mimic an actual root token of a cluster
	created.SetToken(m.RootToken())

	m.client = created
	return m.client, nil
}
// Teardown records that it was called and shuts down the test HTTP server if
// one was started. It always returns nil.
func (m *mockEnvironment) Teardown() error {
	m.teardownCalled = true
	// Setup may never have run; closing a nil server would panic.
	if m.ts != nil {
		m.ts.Close()
	}
	return nil
}
// Name returns the mock environment's name, which is the empty string.
func (m *mockEnvironment) Name() string {
	return ""
}
// MountPath returns the fixed mount path used by the mock environment.
func (m *mockEnvironment) MountPath() string {
	const mount = "/test/"
	return mount
}
// RootToken returns the fixed token that Client sets on the API client it
// creates.
func (m *mockEnvironment) RootToken() string {
	const token = "root-token"
	return token
}
// stepFuncWithoutAuth builds a Step for path and operation that is marked
// Unauthenticated. When shouldError is true the Step's Assert always fails.
func stepFuncWithoutAuth(path string, operation Operation, shouldError bool) Step {
	const unauthenticated = true
	return stepFuncCommon(path, operation, shouldError, unauthenticated)
}
// stepFunc builds an authenticated Step for path and operation. When
// shouldError is true the Step's Assert always fails.
func stepFunc(path string, operation Operation, shouldError bool) Step {
	const unauthenticated = false
	return stepFuncCommon(path, operation, shouldError, unauthenticated)
}
// stepFuncCommon builds a Step with the given path, operation and
// Unauthenticated flag. When shouldError is true the Step gets an Assert
// function that ignores its arguments and always returns an error; otherwise
// Assert is left unset.
func stepFuncCommon(path string, operation Operation, shouldError bool, unauth bool) Step {
	step := Step{
		Operation:       operation,
		Path:            path,
		Unauthenticated: unauth,
	}
	if !shouldError {
		return step
	}

	step.Assert = func(_ *api.Secret, _ error) error {
		return errors.New("some error")
	}
	return step
}
// mockT implements TestT for testing. It does not fail or stop anything; it
// only records which reporting methods were called and the arguments from the
// most recent call to each.
type mockT struct {
	// ErrorCalled and ErrorArgs are set by Error.
	ErrorCalled bool
	ErrorArgs   []interface{}
	// FatalCalled and FatalArgs are set by Fatal.
	FatalCalled bool
	FatalArgs   []interface{}
	// SkipCalled and SkipArgs are set by Skip.
	SkipCalled bool
	SkipArgs   []interface{}
	// f is set to true by any call to Error, Fatal, or Skip.
	f bool
}

// Error records that Error was called and stores args.
func (m *mockT) Error(args ...interface{}) {
	m.ErrorArgs, m.ErrorCalled, m.f = args, true, true
}

// Fatal records that Fatal was called and stores args. It does not stop the
// caller.
func (m *mockT) Fatal(args ...interface{}) {
	m.FatalArgs, m.FatalCalled, m.f = args, true, true
}

// Skip records that Skip was called and stores args. It does not stop the
// caller.
func (m *mockT) Skip(args ...interface{}) {
	m.SkipArgs, m.SkipCalled, m.f = args, true, true
}

// Helper is a no-op.
func (m *mockT) Helper() {}
// checkAuth validates that X-Vault-Token is set on the requests to the mock
// endpoints. When the header is missing or empty it writes a 403 Forbidden
// status; it does not write a body, and it does not stop the calling handler.
func checkAuth(w http.ResponseWriter, r *http.Request) {
	if r.Header.Get("X-Vault-Token") != "" {
		// authenticated
		return
	}
	w.WriteHeader(http.StatusForbidden)
}