Vault SSH: Test case skeleton

This commit is contained in:
Vishal Nayak 2015-07-10 09:56:14 -06:00
parent 3e4b67f90a
commit 3c7dd8611c
6 changed files with 543 additions and 26 deletions

View File

@ -0,0 +1,227 @@
package ssh
import (
"bytes"
"fmt"
"log"
"net"
"os/exec"
"os/user"
"strings"
"testing"
"golang.org/x/crypto/ssh"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
)
const (
testCidr = "127.0.0.1/32"
testRoleName = "testRoleName"
testKey = "testKey"
testPublicKey = `
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv
`
testPrivateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn
Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k
yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM
PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY
rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC
bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH
hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu
c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa
0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq
d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/
fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm
U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2
V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg
j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu
+fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu
vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw
eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96
u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c
+LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW
U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi
fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8
Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw
Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3
7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1
eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw==
-----END RSA PRIVATE KEY-----
`
)
var testIP string
var testPort string
var testUserName string
var testAdminUser string
func init() {
addr, err := startTestServer()
if err != nil {
panic(fmt.Sprintf("Error starting mock server:%s", err))
}
input := strings.Split(addr, ":")
testIP = input[0]
testPort = input[1]
u, err := user.Current()
if err != nil {
panic(fmt.Sprintf("Error getting current username: '%s'", err))
}
testUserName = u.Username
testAdminUser = u.Username
}
func TestSSHBackend(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testNamedKeys(t),
testNewRole(t),
testRoleCreate(t),
},
})
}
func startTestServer() (string, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
if err != nil {
return "", fmt.Errorf("Error parsing public key")
}
serverConfig := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
return &ssh.Permissions{}, nil
} else {
return nil, fmt.Errorf("Key does not match")
}
},
}
signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
if err != nil {
panic("Error parsing private key")
}
serverConfig.AddHostKey(signer)
soc, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", fmt.Errorf("Error listening to connection")
}
go func() {
for {
conn, err := soc.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
}
defer conn.Close()
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
if err != nil {
panic(fmt.Sprintf("Handshaking error: %v", err))
}
go func() {
for chanReq := range chanReqs {
go func(chanReq ssh.NewChannel) {
if chanReq.ChannelType() != "session" {
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
return
}
ch, requests, err := chanReq.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting channel: %s", err))
}
go func(ch ssh.Channel, in <-chan *ssh.Request) {
for req := range in {
executeCommand(ch, req)
}
}(ch, requests)
}(chanReq)
}
sshConn.Close()
}()
}
}()
return soc.Addr().String(), nil
}
func executeCommand(ch ssh.Channel, req *ssh.Request) {
command := string(req.Payload[4:])
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
req.Reply(true, nil)
cmd.Stdout = ch
cmd.Stderr = ch
cmd.Stdin = ch
err := cmd.Start()
if err != nil {
panic(fmt.Sprintf("Error starting the command: '%s'", err))
}
go func() {
_, err := cmd.Process.Wait()
if err != nil {
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
}
ch.Close()
}()
}
func testRoleCreate(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: fmt.Sprintf("creds/%s", testRoleName),
Data: map[string]interface{}{
"username": testUserName,
"ip": testIP,
},
Check: func(resp *logical.Response) error {
var d struct {
Key string `mapstructure:"key"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated Key:%s\n", d.Key)
if d.Key == "" {
return fmt.Errorf("Generated key is an empty string")
}
_, err := ssh.ParsePrivateKey([]byte(d.Key))
if err != nil {
return fmt.Errorf("Generated key is invalid")
}
return nil
},
}
}
func testNewRole(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: fmt.Sprintf("roles/%s", testRoleName),
Data: map[string]interface{}{
"key": testKey,
"admin_user": testAdminUser,
"cidr": testCidr,
"port": testPort,
},
}
}
func testNamedKeys(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: fmt.Sprintf("keys/%s", testKey),
Data: map[string]interface{}{
"key": testPrivateKey,
},
}
}

View File

@ -2,6 +2,7 @@ package ssh
import (
"fmt"
"log"
"net"
"github.com/hashicorp/vault/logical"
@ -35,6 +36,7 @@ func pathRoleCreate(b *backend) *framework.Path {
func (b *backend) pathRoleCreateWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
log.Printf("Vishal: pathRoleCreateWrite\n")
roleName := d.Get("name").(string)
username := d.Get("username").(string)
ipRaw := d.Get("ip").(string)
@ -92,9 +94,11 @@ func (b *backend) pathRoleCreateWrite(
// Transfer the public key to target machine
err = uploadPublicKeyScp(dynamicPublicKey, username, ip, role.Port, hostKey.Key)
//return nil, nil //TODO remove this
if err != nil {
return nil, err
}
log.Printf("Vishal: uploaded public key file to target\n")
// Add the public key to authorized_keys file in target machine
err = installPublicKeyInTarget(username, ip, role.Port, hostKey.Key)
@ -102,6 +106,7 @@ func (b *backend) pathRoleCreateWrite(
return nil, fmt.Errorf("error adding public key to authorized_keys file in target")
}
log.Printf("Vishal: installed public key file to target\n")
result := b.Secret(SecretDynamicKeyType).Response(map[string]interface{}{
"key": dynamicPrivateKey,
}, map[string]interface{}{

View File

@ -8,6 +8,7 @@ import (
"encoding/pem"
"fmt"
"io"
"log"
"net"
"strings"
@ -36,9 +37,9 @@ func uploadPublicKeyScp(publicKey, username, ip, port, key string) error {
fmt.Fprint(w, "\x00")
w.Close()
}()
if err := session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName)); err != nil {
return err
}
log.Printf("Vishal: uploading now\n")
err = session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName))
log.Printf("Vishal: upload completed: err:%s\n", err)
return nil
}
@ -113,22 +114,22 @@ func installPublicKeyInTarget(username, ip, port, hostKey string) error {
}
defer session.Close()
authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username)
tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username)
authKeysFileName := "~/.ssh/authorized_keys"
tempKeysFileName := "~/temp_authorized_keys"
// Commands to be run on target machine
dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip)
grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName)
catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName)
catCmdAppendNew := fmt.Sprintf("cat %s >> %s", dynamicPublicKeyFileName, authKeysFileName)
removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
//removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
log.Printf(grepCmd)
log.Printf(catCmdRemoveDuplicate)
log.Printf(catCmdAppendNew)
targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd)
// Run the commands on target machine
if err := session.Run(targetCmd); err != nil {
return err
}
//targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd)
targetCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew)
session.Run(targetCmd)
return nil
}
@ -143,8 +144,8 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error {
}
defer session.Close()
authKeysFileName := "/home/" + username + "/.ssh/authorized_keys"
tempKeysFileName := "/home/" + username + "/temp_authorized_keys"
authKeysFileName := "~/.ssh/authorized_keys"
tempKeysFileName := "~/temp_authorized_keys"
// Commands to be run on target machine
dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip)
@ -153,11 +154,7 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error {
removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd)
// Run the commands in target machine
if err := session.Run(remoteCmd); err != nil {
return err
}
session.Run(remoteCmd)
return nil
}

View File

@ -51,7 +51,7 @@ func (c *SSHCommand) Run(args []string) int {
c.Ui.Error(fmt.Sprintf("Error setting default role: %s", err.Error()))
return 1
}
c.Ui.Output(fmt.Sprintf("Using role[%s]\n", role))
c.Ui.Output(fmt.Sprintf("Vault SSH: Role:'%s'\n", role))
}
data := map[string]interface{}{
@ -72,10 +72,14 @@ func (c *SSHCommand) Run(args []string) int {
sshDynamicKeyFileName := fmt.Sprintf("vault_temp_file_%s_%s", username, ip.String())
err = ioutil.WriteFile(sshDynamicKeyFileName, []byte(sshDynamicKey), 0600)
cmd := exec.Command("ssh", "-p", port, "-i", sshDynamicKeyFileName, args[0])
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
err = cmd.Run()
sshCmdArgs := []string{"-p", port, "-i", sshDynamicKeyFileName}
sshCmdArgs = append(sshCmdArgs, args...)
sshCmd := exec.Command("ssh", sshCmdArgs...)
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
err = sshCmd.Run()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error while running ssh command:%s", err))
}
@ -145,6 +149,8 @@ SSH Options:
skip mentioning the role. It will be chosen by default.
If there are no roless associated with the IP, register
the CIDR block of that IP using the "roles/" endpoint.
-port Port number to use for SSH connection. This defaults to port 22.
`
return strings.TrimSpace(helpText)
}

259
command/ssh_test.go Normal file
View File

@ -0,0 +1,259 @@
package command
import (
"bytes"
"fmt"
"log"
"net"
"os/exec"
"os/user"
"strings"
"testing"
logicalssh "github.com/hashicorp/vault/builtin/logical/ssh"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
"golang.org/x/crypto/ssh"
)
const (
testCidr = "127.0.0.1/32"
testRoleName = "testRoleName"
testKey = "testKey"
testPublicKey = `
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv
`
testPrivateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn
Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k
yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM
PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY
rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC
bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH
hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu
c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa
0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq
d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/
fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm
U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2
V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg
j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu
+fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu
vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw
eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96
u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c
+LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW
U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi
fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8
Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw
Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3
7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1
eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw==
-----END RSA PRIVATE KEY-----
`
)
var testIP string
var testPort string
var testUserName string
var testAdminUser string
func init() {
addr, err := startTestServer()
if err != nil {
panic(fmt.Sprintf("Error starting mock server:%s", err))
}
input := strings.Split(addr, ":")
testIP = input[0]
testPort = input[1]
//testPort = "22"
u, err := user.Current()
if err != nil {
panic(fmt.Sprintf("Error getting current username: '%s'", err))
}
testUserName = u.Username
testAdminUser = u.Username
//testUserName = "vishal" //TODO: remove this
//testAdminUser = "vishal" //TODO: remove this
}
func TestSSH(t *testing.T) {
err := vault.AddTestLogicalBackend("ssh", logicalssh.Factory)
if err != nil {
t.Fatalf("err: %s", err)
}
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
mountCmd := &MountCommand{
Meta: Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{"-address", addr, "ssh"}
log.Printf("Vishal: mount args: %#v\n", args)
if code := mountCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
client, err := mountCmd.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
mounts, err := client.Sys().ListMounts()
if err != nil {
t.Fatalf("err: %s", err)
}
mount, ok := mounts["ssh/"]
if !ok {
t.Fatal("should have ssh mount")
}
if mount.Type != "ssh" {
t.Fatal("should have ssh type")
}
writeCmd := &WriteCommand{
Meta: Meta{
ClientToken: token,
Ui: ui,
},
}
args = []string{
"-address", addr,
"ssh/keys/" + testKey,
"key=" + testPrivateKey,
}
log.Printf("Vishal: write args: %#v\n", args)
if code := writeCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
args = []string{
"-address", addr,
"ssh/roles/" + testRoleName,
"key=" + testKey,
"admin_user=" + testUserName,
"cidr=" + testCidr,
"port=" + testPort,
}
log.Printf("Vishal: write args: %#v\n", args)
if code := writeCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
log.Printf("Vishal: Reached here\n")
sshCmd := &SSHCommand{
Meta: Meta{
ClientToken: token,
Ui: ui,
},
}
args = []string{
"-address", addr,
"-role=" + testRoleName,
testUserName + "@" + testIP,
"/usr/bin/whoami",
}
log.Printf("Vishal: ssh args: %#v\n", args)
if code := sshCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
log.Printf("addr:%s testRoleName:%s testUserName:%s testIP:%s testPort:%s\n", addr, testRoleName, testUserName, testIP, testPort)
// TODO: Compare the testUserName and response of whoami should match! else fail test.
}
func executeCommand(ch ssh.Channel, req *ssh.Request) {
command := string(req.Payload[4:])
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
req.Reply(true, nil)
cmd.Stdout = ch
cmd.Stderr = ch
cmd.Stdin = ch
err := cmd.Start()
if err != nil {
panic(fmt.Sprintf("Error starting the command: '%s'", err))
}
go func() {
_, err := cmd.Process.Wait()
if err != nil {
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
}
ch.Close()
}()
}
func startTestServer() (string, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
if err != nil {
return "", fmt.Errorf("Error parsing public key")
}
serverConfig := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
return &ssh.Permissions{}, nil
} else {
return nil, fmt.Errorf("Key does not match")
}
},
}
signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
if err != nil {
panic("Error parsing private key")
}
serverConfig.AddHostKey(signer)
soc, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", fmt.Errorf("Error listening to connection")
}
go func() {
for {
conn, err := soc.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
}
defer conn.Close()
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
if err != nil {
panic(fmt.Sprintf("Handshaking error: %v", err))
}
go func() {
for chanReq := range chanReqs {
go func(chanReq ssh.NewChannel) {
if chanReq.ChannelType() != "session" {
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
return
}
ch, requests, err := chanReq.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting channel: %s", err))
}
go func(ch ssh.Channel, in <-chan *ssh.Request) {
for req := range in {
executeCommand(ch, req)
}
}(ch, requests)
}(chanReq)
}
sshConn.Close()
}()
}
}()
return soc.Addr().String(), nil
}

View File

@ -1,6 +1,7 @@
package vault
import (
"fmt"
"testing"
"github.com/hashicorp/vault/audit"
@ -26,12 +27,19 @@ func TestCore(t *testing.T) *Core {
noopBackends["http"] = func(*logical.BackendConfig) (logical.Backend, error) {
return new(rawHTTP), nil
}
logicalBackends := make(map[string]logical.Factory)
for backendName, backendFactory := range noopBackends {
logicalBackends[backendName] = backendFactory
}
for backendName, backendFactory := range testLogicalBackends {
logicalBackends[backendName] = backendFactory
}
physicalBackend := physical.NewInmem()
c, err := NewCore(&CoreConfig{
Physical: physicalBackend,
AuditBackends: noopAudits,
LogicalBackends: noopBackends,
LogicalBackends: logicalBackends,
CredentialBackends: noopBackends,
DisableMlock: true,
})
@ -83,6 +91,21 @@ func TestKeyCopy(key []byte) []byte {
return result
}
var testLogicalBackends = map[string]logical.Factory{}
// This adds a logical backend for the test core. This needs to be
// invoked before the test core is created.
func AddTestLogicalBackend(name string, factory logical.Factory) error {
if name == "" {
return fmt.Errorf("Missing backend name")
}
if factory == nil {
return fmt.Errorf("Missing backend factory function")
}
testLogicalBackends[name] = factory
return nil
}
type noopAudit struct{}
func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {