open-vault/api/client_test.go

408 lines
9.8 KiB
Go
Raw Normal View History

2015-03-11 22:42:08 +00:00
package api
import (
2015-04-20 18:30:35 +00:00
"bytes"
"io"
2015-03-11 22:42:08 +00:00
"net/http"
"os"
"strings"
2015-03-11 22:42:08 +00:00
"testing"
"github.com/hashicorp/vault/sdk/helper/consts"
2015-03-11 22:42:08 +00:00
)
func init() {
// Ensure our special envvars are not present
2015-04-23 15:46:22 +00:00
os.Setenv("VAULT_ADDR", "")
os.Setenv("VAULT_TOKEN", "")
}
func TestDefaultConfig_envvar(t *testing.T) {
2015-04-23 15:46:22 +00:00
os.Setenv("VAULT_ADDR", "https://vault.mycompany.com")
defer os.Setenv("VAULT_ADDR", "")
config := DefaultConfig()
if config.Address != "https://vault.mycompany.com" {
t.Fatalf("bad: %s", config.Address)
}
os.Setenv("VAULT_TOKEN", "testing")
defer os.Setenv("VAULT_TOKEN", "")
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
if token := client.Token(); token != "testing" {
t.Fatalf("bad: %s", token)
}
}
func TestClientDefaultHttpClient(t *testing.T) {
_, err := NewClient(&Config{
HttpClient: http.DefaultClient,
})
if err != nil {
t.Fatal(err)
}
}
func TestClientNilConfig(t *testing.T) {
client, err := NewClient(nil)
if err != nil {
t.Fatal(err)
}
if client == nil {
t.Fatal("expected a non-nil client")
}
}
func TestClientSetAddress(t *testing.T) {
client, err := NewClient(nil)
if err != nil {
t.Fatal(err)
}
if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
t.Fatal(err)
}
if client.addr.Host != "172.168.2.1:8300" {
t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
}
}
2015-03-11 22:42:08 +00:00
func TestClientToken(t *testing.T) {
tokenValue := "foo"
2015-08-22 00:36:19 +00:00
handler := func(w http.ResponseWriter, req *http.Request) {}
2015-03-11 22:42:08 +00:00
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
2015-08-22 00:36:19 +00:00
client.SetToken(tokenValue)
2015-03-11 22:42:08 +00:00
// Verify the token is set
if v := client.Token(); v != tokenValue {
t.Fatalf("bad: %s", v)
}
client.ClearToken()
if v := client.Token(); v != "" {
t.Fatalf("bad: %s", v)
}
}
2015-03-31 04:20:23 +00:00
func TestClientHostHeader(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Host))
}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
config.Address = strings.ReplaceAll(config.Address, "127.0.0.1", "localhost")
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
// Set the token manually
client.SetToken("foo")
resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
if err != nil {
t.Fatal(err)
}
// Copy the response
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
// Verify we got the response from the primary
if buf.String() != strings.ReplaceAll(config.Address, "http://", "") {
t.Fatalf("Bad address: %s", buf.String())
}
}
func TestClientBadToken(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
client.SetToken("foo")
_, err = client.RawRequest(client.NewRequest("PUT", "/"))
if err != nil {
t.Fatal(err)
}
client.SetToken("foo\u007f")
_, err = client.RawRequest(client.NewRequest("PUT", "/"))
if err == nil || !strings.Contains(err.Error(), "printable") {
t.Fatalf("expected error due to bad token")
}
}
2015-04-20 18:30:35 +00:00
func TestClientRedirect(t *testing.T) {
primary := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
config, ln := testHTTPServer(t, http.HandlerFunc(primary))
defer ln.Close()
standby := func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Location", config.Address)
w.WriteHeader(307)
}
config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
defer ln2.Close()
client, err := NewClient(config2)
if err != nil {
t.Fatalf("err: %s", err)
}
2015-09-03 14:36:59 +00:00
// Set the token manually
2015-04-20 18:30:35 +00:00
client.SetToken("foo")
// Do a raw "/" request
resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
// Copy the response
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
// Verify we got the response from the primary
if buf.String() != "test" {
t.Fatalf("Bad: %s", buf.String())
}
}
func TestClientEnvSettings(t *testing.T) {
cwd, _ := os.Getwd()
oldCACert := os.Getenv(EnvVaultCACert)
oldCAPath := os.Getenv(EnvVaultCAPath)
oldClientCert := os.Getenv(EnvVaultClientCert)
oldClientKey := os.Getenv(EnvVaultClientKey)
oldSkipVerify := os.Getenv(EnvVaultSkipVerify)
oldMaxRetries := os.Getenv(EnvVaultMaxRetries)
os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys")
os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem")
os.Setenv(EnvVaultSkipVerify, "true")
os.Setenv(EnvVaultMaxRetries, "5")
defer os.Setenv(EnvVaultCACert, oldCACert)
defer os.Setenv(EnvVaultCAPath, oldCAPath)
defer os.Setenv(EnvVaultClientCert, oldClientCert)
defer os.Setenv(EnvVaultClientKey, oldClientKey)
defer os.Setenv(EnvVaultSkipVerify, oldSkipVerify)
defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries)
config := DefaultConfig()
if err := config.ReadEnvironment(); err != nil {
t.Fatalf("error reading environment: %v", err)
}
tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
if len(tlsConfig.RootCAs.Subjects()) == 0 {
t.Fatalf("bad: expected a cert pool with at least one subject")
}
2017-11-13 15:35:36 +00:00
if tlsConfig.GetClientCertificate == nil {
2017-11-12 17:34:56 +00:00
t.Fatalf("bad: expected client tls config to have a certificate getter")
}
if tlsConfig.InsecureSkipVerify != true {
2016-04-13 19:38:29 +00:00
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
}
}
func TestClientDeprecatedEnvSettings(t *testing.T) {
oldInsecure := os.Getenv(EnvVaultInsecure)
os.Setenv(EnvVaultInsecure, "true")
defer os.Setenv(EnvVaultInsecure, oldInsecure)
config := DefaultConfig()
if err := config.ReadEnvironment(); err != nil {
t.Fatalf("error reading environment: %v", err)
}
tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
if tlsConfig.InsecureSkipVerify != true {
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
}
}
func TestClientEnvNamespace(t *testing.T) {
var seenNamespace string
handler := func(w http.ResponseWriter, req *http.Request) {
seenNamespace = req.Header.Get(consts.NamespaceHeaderName)
}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
oldVaultNamespace := os.Getenv(EnvVaultNamespace)
defer os.Setenv(EnvVaultNamespace, oldVaultNamespace)
os.Setenv(EnvVaultNamespace, "test")
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
_, err = client.RawRequest(client.NewRequest("GET", "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
if seenNamespace != "test" {
t.Fatalf("Bad: %s", seenNamespace)
}
}
2018-05-11 14:42:06 +00:00
func TestParsingRateAndBurst(t *testing.T) {
var (
correctFormat = "400:400"
observedRate, observedBurst, err = parseRateLimit(correctFormat)
expectedRate, expectedBurst = float64(400), 400
2018-05-11 14:42:06 +00:00
)
if err != nil {
t.Error(err)
}
if expectedRate != observedRate {
t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
}
if expectedBurst != observedBurst {
t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
2018-05-11 14:42:06 +00:00
}
}
func TestParsingRateOnly(t *testing.T) {
var (
correctFormat = "400"
observedRate, observedBurst, err = parseRateLimit(correctFormat)
expectedRate, expectedBurst = float64(400), 400
2018-05-11 14:42:06 +00:00
)
if err != nil {
t.Error(err)
}
if expectedRate != observedRate {
t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
}
if expectedBurst != observedBurst {
t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
2018-05-11 14:42:06 +00:00
}
}
func TestParsingErrorCase(t *testing.T) {
var incorrectFormat = "foobar"
var _, _, err = parseRateLimit(incorrectFormat)
2018-05-11 14:42:06 +00:00
if err == nil {
t.Error("Expected error, found no error")
}
}
func TestClientTimeoutSetting(t *testing.T) {
oldClientTimeout := os.Getenv(EnvVaultClientTimeout)
os.Setenv(EnvVaultClientTimeout, "10")
defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)
config := DefaultConfig()
config.ReadEnvironment()
_, err := NewClient(config)
if err != nil {
t.Fatal(err)
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (rt roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return rt(r)
}
func TestClientNonTransportRoundTripper(t *testing.T) {
client := &http.Client{
Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
}
_, err := NewClient(&Config{
HttpClient: client,
})
if err != nil {
t.Fatal(err)
}
}
func TestClone(t *testing.T) {
client1, err1 := NewClient(nil)
if err1 != nil {
t.Fatalf("NewClient failed: %v", err1)
}
client2, err2 := client1.Clone()
if err2 != nil {
t.Fatalf("Clone failed: %v", err2)
}
_ = client2
}
func TestSetHeadersRaceSafe(t *testing.T) {
client, err1 := NewClient(nil)
if err1 != nil {
t.Fatalf("NewClient failed: %v", err1)
}
start := make(chan interface{})
done := make(chan interface{})
testPairs := map[string]string{
"soda": "rootbeer",
"veggie": "carrots",
"fruit": "apples",
"color": "red",
"protein": "egg",
}
for key, value := range testPairs {
tmpKey := key
tmpValue := value
go func() {
<-start
// This test fails if here, you replace client.AddHeader(tmpKey, tmpValue) with:
// headerCopy := client.Header()
// headerCopy.AddHeader(tmpKey, tmpValue)
// client.SetHeader(headerCopy)
client.AddHeader(tmpKey, tmpValue)
done <- true
}()
}
// Start everyone at once.
close(start)
// Wait until everyone is done.
for i := 0; i < len(testPairs); i++ {
<-done
}
// Check that all the test pairs are in the resulting
// headers.
resultingHeaders := client.Headers()
for key, value := range testPairs {
if resultingHeaders.Get(key) != value {
t.Fatal("expected " + value + " for " + key)
}
}
}