113 lines
2.8 KiB
Go
113 lines
2.8 KiB
Go
|
// Copyright (c) HashiCorp, Inc.
|
||
|
// SPDX-License-Identifier: MPL-2.0
|
||
|
|
||
|
package api
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func TestSSH_CreateTLSClient(t *testing.T) {
|
||
|
// load the default configuration
|
||
|
config, err := LoadSSHHelperConfig("./test-fixtures/agent_config.hcl")
|
||
|
if err != nil {
|
||
|
panic(fmt.Sprintf("error loading agent's config file: %s", err))
|
||
|
}
|
||
|
|
||
|
client, err := config.NewClient()
|
||
|
if err != nil {
|
||
|
panic(fmt.Sprintf("error creating the client: %s", err))
|
||
|
}
|
||
|
|
||
|
// Provide a certificate and enforce setting of transport
|
||
|
config.CACert = "./test-fixtures/vault.crt"
|
||
|
|
||
|
client, err = config.NewClient()
|
||
|
if err != nil {
|
||
|
panic(fmt.Sprintf("error creating the client: %s", err))
|
||
|
}
|
||
|
if client.config.HttpClient.Transport == nil {
|
||
|
panic(fmt.Sprintf("error creating client with TLS transport"))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestSSH_CreateTLSClient_tlsServerName(t *testing.T) {
|
||
|
// Ensure that the HTTP client is associated with the configured TLS server name.
|
||
|
tlsServerName := "tls.server.name"
|
||
|
|
||
|
config, err := ParseSSHHelperConfig(fmt.Sprintf(`
|
||
|
vault_addr = "1.2.3.4"
|
||
|
tls_server_name = "%s"
|
||
|
`, tlsServerName))
|
||
|
if err != nil {
|
||
|
panic(fmt.Sprintf("error loading config: %s", err))
|
||
|
}
|
||
|
|
||
|
client, err := config.NewClient()
|
||
|
if err != nil {
|
||
|
panic(fmt.Sprintf("error creating the client: %s", err))
|
||
|
}
|
||
|
|
||
|
actualTLSServerName := client.config.HttpClient.Transport.(*http.Transport).TLSClientConfig.ServerName
|
||
|
if actualTLSServerName != tlsServerName {
|
||
|
panic(fmt.Sprintf("incorrect TLS server name. expected: %s actual: %s", tlsServerName, actualTLSServerName))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestParseSSHHelperConfig(t *testing.T) {
|
||
|
config, err := ParseSSHHelperConfig(`
|
||
|
vault_addr = "1.2.3.4"
|
||
|
`)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
if config.SSHMountPoint != SSHHelperDefaultMountPoint {
|
||
|
t.Errorf("expected %q to be %q", config.SSHMountPoint, SSHHelperDefaultMountPoint)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestParseSSHHelperConfig_missingVaultAddr(t *testing.T) {
|
||
|
_, err := ParseSSHHelperConfig("")
|
||
|
if err == nil {
|
||
|
t.Fatal("expected error")
|
||
|
}
|
||
|
|
||
|
if !strings.Contains(err.Error(), `missing config "vault_addr"`) {
|
||
|
t.Errorf("bad error: %s", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestParseSSHHelperConfig_badKeys(t *testing.T) {
|
||
|
_, err := ParseSSHHelperConfig(`
|
||
|
vault_addr = "1.2.3.4"
|
||
|
nope = "bad"
|
||
|
`)
|
||
|
if err == nil {
|
||
|
t.Fatal("expected error")
|
||
|
}
|
||
|
|
||
|
if !strings.Contains(err.Error(), `ssh_helper: invalid key "nope" on line 3`) {
|
||
|
t.Errorf("bad error: %s", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestParseSSHHelperConfig_tlsServerName(t *testing.T) {
|
||
|
tlsServerName := "tls.server.name"
|
||
|
|
||
|
config, err := ParseSSHHelperConfig(fmt.Sprintf(`
|
||
|
vault_addr = "1.2.3.4"
|
||
|
tls_server_name = "%s"
|
||
|
`, tlsServerName))
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
if config.TLSServerName != tlsServerName {
|
||
|
t.Errorf("incorrect TLS server name. expected: %s actual: %s", tlsServerName, config.TLSServerName)
|
||
|
}
|
||
|
}
|