Vault SSH: Test case skeleton

This commit is contained in:
Vishal Nayak 2015-07-10 09:56:14 -06:00
parent 3e4b67f90a
commit 3c7dd8611c
6 changed files with 543 additions and 26 deletions

View File

@ -0,0 +1,227 @@
package ssh
import (
"bytes"
"fmt"
"log"
"net"
"os/exec"
"os/user"
"strings"
"testing"
"golang.org/x/crypto/ssh"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
)
const (
testCidr = "127.0.0.1/32"
testRoleName = "testRoleName"
testKey = "testKey"
testPublicKey = `
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv
`
testPrivateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn
Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k
yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM
PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY
rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC
bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH
hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu
c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa
0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq
d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/
fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm
U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2
V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg
j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu
+fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu
vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw
eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96
u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c
+LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW
U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi
fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8
Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw
Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3
7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1
eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw==
-----END RSA PRIVATE KEY-----
`
)
var testIP string
var testPort string
var testUserName string
var testAdminUser string
func init() {
addr, err := startTestServer()
if err != nil {
panic(fmt.Sprintf("Error starting mock server:%s", err))
}
input := strings.Split(addr, ":")
testIP = input[0]
testPort = input[1]
u, err := user.Current()
if err != nil {
panic(fmt.Sprintf("Error getting current username: '%s'", err))
}
testUserName = u.Username
testAdminUser = u.Username
}
func TestSSHBackend(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testNamedKeys(t),
testNewRole(t),
testRoleCreate(t),
},
})
}
func startTestServer() (string, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
if err != nil {
return "", fmt.Errorf("Error parsing public key")
}
serverConfig := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
return &ssh.Permissions{}, nil
} else {
return nil, fmt.Errorf("Key does not match")
}
},
}
signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
if err != nil {
panic("Error parsing private key")
}
serverConfig.AddHostKey(signer)
soc, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", fmt.Errorf("Error listening to connection")
}
go func() {
for {
conn, err := soc.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
}
defer conn.Close()
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
if err != nil {
panic(fmt.Sprintf("Handshaking error: %v", err))
}
go func() {
for chanReq := range chanReqs {
go func(chanReq ssh.NewChannel) {
if chanReq.ChannelType() != "session" {
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
return
}
ch, requests, err := chanReq.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting channel: %s", err))
}
go func(ch ssh.Channel, in <-chan *ssh.Request) {
for req := range in {
executeCommand(ch, req)
}
}(ch, requests)
}(chanReq)
}
sshConn.Close()
}()
}
}()
return soc.Addr().String(), nil
}
func executeCommand(ch ssh.Channel, req *ssh.Request) {
command := string(req.Payload[4:])
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
req.Reply(true, nil)
cmd.Stdout = ch
cmd.Stderr = ch
cmd.Stdin = ch
err := cmd.Start()
if err != nil {
panic(fmt.Sprintf("Error starting the command: '%s'", err))
}
go func() {
_, err := cmd.Process.Wait()
if err != nil {
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
}
ch.Close()
}()
}
func testRoleCreate(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: fmt.Sprintf("creds/%s", testRoleName),
Data: map[string]interface{}{
"username": testUserName,
"ip": testIP,
},
Check: func(resp *logical.Response) error {
var d struct {
Key string `mapstructure:"key"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated Key:%s\n", d.Key)
if d.Key == "" {
return fmt.Errorf("Generated key is an empty string")
}
_, err := ssh.ParsePrivateKey([]byte(d.Key))
if err != nil {
return fmt.Errorf("Generated key is invalid")
}
return nil
},
}
}
func testNewRole(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: fmt.Sprintf("roles/%s", testRoleName),
Data: map[string]interface{}{
"key": testKey,
"admin_user": testAdminUser,
"cidr": testCidr,
"port": testPort,
},
}
}
func testNamedKeys(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: fmt.Sprintf("keys/%s", testKey),
Data: map[string]interface{}{
"key": testPrivateKey,
},
}
}

View File

@ -2,6 +2,7 @@ package ssh
import (
"fmt"
"log"
"net"
"github.com/hashicorp/vault/logical"
@ -35,6 +36,7 @@ func pathRoleCreate(b *backend) *framework.Path {
func (b *backend) pathRoleCreateWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
log.Printf("Vishal: pathRoleCreateWrite\n")
roleName := d.Get("name").(string)
username := d.Get("username").(string)
ipRaw := d.Get("ip").(string)
@ -92,9 +94,11 @@ func (b *backend) pathRoleCreateWrite(
// Transfer the public key to target machine
err = uploadPublicKeyScp(dynamicPublicKey, username, ip, role.Port, hostKey.Key)
//return nil, nil //TODO remove this
if err != nil {
return nil, err
}
log.Printf("Vishal: uploaded public key file to target\n")
// Add the public key to authorized_keys file in target machine
err = installPublicKeyInTarget(username, ip, role.Port, hostKey.Key)
@ -102,6 +106,7 @@ func (b *backend) pathRoleCreateWrite(
return nil, fmt.Errorf("error adding public key to authorized_keys file in target")
}
log.Printf("Vishal: installed public key file to target\n")
result := b.Secret(SecretDynamicKeyType).Response(map[string]interface{}{
"key": dynamicPrivateKey,
}, map[string]interface{}{

View File

@ -8,6 +8,7 @@ import (
"encoding/pem"
"fmt"
"io"
"log"
"net"
"strings"
@ -36,9 +37,9 @@ func uploadPublicKeyScp(publicKey, username, ip, port, key string) error {
fmt.Fprint(w, "\x00")
w.Close()
}()
if err := session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName)); err != nil {
return err
}
log.Printf("Vishal: uploading now\n")
err = session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName))
log.Printf("Vishal: upload completed: err:%s\n", err)
return nil
}
@ -113,22 +114,22 @@ func installPublicKeyInTarget(username, ip, port, hostKey string) error {
}
defer session.Close()
authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username)
tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username)
authKeysFileName := "~/.ssh/authorized_keys"
tempKeysFileName := "~/temp_authorized_keys"
// Commands to be run on target machine
dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip)
grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName)
catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName)
catCmdAppendNew := fmt.Sprintf("cat %s >> %s", dynamicPublicKeyFileName, authKeysFileName)
removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
//removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
log.Printf(grepCmd)
log.Printf(catCmdRemoveDuplicate)
log.Printf(catCmdAppendNew)
targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd)
// Run the commands on target machine
if err := session.Run(targetCmd); err != nil {
return err
}
//targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd)
targetCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew)
session.Run(targetCmd)
return nil
}
@ -143,8 +144,8 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error {
}
defer session.Close()
authKeysFileName := "/home/" + username + "/.ssh/authorized_keys"
tempKeysFileName := "/home/" + username + "/temp_authorized_keys"
authKeysFileName := "~/.ssh/authorized_keys"
tempKeysFileName := "~/temp_authorized_keys"
// Commands to be run on target machine
dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip)
@ -153,11 +154,7 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error {
removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd)
// Run the commands in target machine
if err := session.Run(remoteCmd); err != nil {
return err
}
session.Run(remoteCmd)
return nil
}

View File

@ -51,7 +51,7 @@ func (c *SSHCommand) Run(args []string) int {
c.Ui.Error(fmt.Sprintf("Error setting default role: %s", err.Error()))
return 1
}
c.Ui.Output(fmt.Sprintf("Using role[%s]\n", role))
c.Ui.Output(fmt.Sprintf("Vault SSH: Role:'%s'\n", role))
}
data := map[string]interface{}{
@ -72,10 +72,14 @@ func (c *SSHCommand) Run(args []string) int {
sshDynamicKeyFileName := fmt.Sprintf("vault_temp_file_%s_%s", username, ip.String())
err = ioutil.WriteFile(sshDynamicKeyFileName, []byte(sshDynamicKey), 0600)
cmd := exec.Command("ssh", "-p", port, "-i", sshDynamicKeyFileName, args[0])
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
err = cmd.Run()
sshCmdArgs := []string{"-p", port, "-i", sshDynamicKeyFileName}
sshCmdArgs = append(sshCmdArgs, args...)
sshCmd := exec.Command("ssh", sshCmdArgs...)
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
err = sshCmd.Run()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error while running ssh command:%s", err))
}
@ -138,13 +142,15 @@ General Options:
SSH Options:
-role Mention the role to be used to create dynamic key.
-role Mention the role to be used to create dynamic key.
Each IP is associated with a role. To see the associated
roles with IP, use "lookup" endpoint. If you are certain that
there is only one role associated with the IP, you can
skip mentioning the role. It will be chosen by default.
If there are no roless associated with the IP, register
the CIDR block of that IP using the "roles/" endpoint.
-port Port number to use for SSH connection. This defaults to port 22.
`
return strings.TrimSpace(helpText)
}

259
command/ssh_test.go Normal file
View File

@ -0,0 +1,259 @@
package command
import (
"bytes"
"fmt"
"log"
"net"
"os/exec"
"os/user"
"strings"
"testing"
logicalssh "github.com/hashicorp/vault/builtin/logical/ssh"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
"golang.org/x/crypto/ssh"
)
const (
testCidr = "127.0.0.1/32"
testRoleName = "testRoleName"
testKey = "testKey"
testPublicKey = `
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv
`
testPrivateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn
Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k
yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM
PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY
rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC
bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH
hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu
c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa
0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq
d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/
fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm
U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2
V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg
j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu
+fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu
vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw
eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96
u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c
+LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW
U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi
fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8
Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw
Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3
7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1
eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw==
-----END RSA PRIVATE KEY-----
`
)
var testIP string
var testPort string
var testUserName string
var testAdminUser string
func init() {
addr, err := startTestServer()
if err != nil {
panic(fmt.Sprintf("Error starting mock server:%s", err))
}
input := strings.Split(addr, ":")
testIP = input[0]
testPort = input[1]
//testPort = "22"
u, err := user.Current()
if err != nil {
panic(fmt.Sprintf("Error getting current username: '%s'", err))
}
testUserName = u.Username
testAdminUser = u.Username
//testUserName = "vishal" //TODO: remove this
//testAdminUser = "vishal" //TODO: remove this
}
func TestSSH(t *testing.T) {
err := vault.AddTestLogicalBackend("ssh", logicalssh.Factory)
if err != nil {
t.Fatalf("err: %s", err)
}
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
mountCmd := &MountCommand{
Meta: Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{"-address", addr, "ssh"}
log.Printf("Vishal: mount args: %#v\n", args)
if code := mountCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
client, err := mountCmd.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
mounts, err := client.Sys().ListMounts()
if err != nil {
t.Fatalf("err: %s", err)
}
mount, ok := mounts["ssh/"]
if !ok {
t.Fatal("should have ssh mount")
}
if mount.Type != "ssh" {
t.Fatal("should have ssh type")
}
writeCmd := &WriteCommand{
Meta: Meta{
ClientToken: token,
Ui: ui,
},
}
args = []string{
"-address", addr,
"ssh/keys/" + testKey,
"key=" + testPrivateKey,
}
log.Printf("Vishal: write args: %#v\n", args)
if code := writeCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
args = []string{
"-address", addr,
"ssh/roles/" + testRoleName,
"key=" + testKey,
"admin_user=" + testUserName,
"cidr=" + testCidr,
"port=" + testPort,
}
log.Printf("Vishal: write args: %#v\n", args)
if code := writeCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
log.Printf("Vishal: Reached here\n")
sshCmd := &SSHCommand{
Meta: Meta{
ClientToken: token,
Ui: ui,
},
}
args = []string{
"-address", addr,
"-role=" + testRoleName,
testUserName + "@" + testIP,
"/usr/bin/whoami",
}
log.Printf("Vishal: ssh args: %#v\n", args)
if code := sshCmd.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
log.Printf("addr:%s testRoleName:%s testUserName:%s testIP:%s testPort:%s\n", addr, testRoleName, testUserName, testIP, testPort)
// TODO: Compare the testUserName and response of whoami should match! else fail test.
}
func executeCommand(ch ssh.Channel, req *ssh.Request) {
command := string(req.Payload[4:])
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
req.Reply(true, nil)
cmd.Stdout = ch
cmd.Stderr = ch
cmd.Stdin = ch
err := cmd.Start()
if err != nil {
panic(fmt.Sprintf("Error starting the command: '%s'", err))
}
go func() {
_, err := cmd.Process.Wait()
if err != nil {
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
}
ch.Close()
}()
}
func startTestServer() (string, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
if err != nil {
return "", fmt.Errorf("Error parsing public key")
}
serverConfig := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
return &ssh.Permissions{}, nil
} else {
return nil, fmt.Errorf("Key does not match")
}
},
}
signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
if err != nil {
panic("Error parsing private key")
}
serverConfig.AddHostKey(signer)
soc, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", fmt.Errorf("Error listening to connection")
}
go func() {
for {
conn, err := soc.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
}
defer conn.Close()
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
if err != nil {
panic(fmt.Sprintf("Handshaking error: %v", err))
}
go func() {
for chanReq := range chanReqs {
go func(chanReq ssh.NewChannel) {
if chanReq.ChannelType() != "session" {
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
return
}
ch, requests, err := chanReq.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting channel: %s", err))
}
go func(ch ssh.Channel, in <-chan *ssh.Request) {
for req := range in {
executeCommand(ch, req)
}
}(ch, requests)
}(chanReq)
}
sshConn.Close()
}()
}
}()
return soc.Addr().String(), nil
}

View File

@ -1,6 +1,7 @@
package vault
import (
"fmt"
"testing"
"github.com/hashicorp/vault/audit"
@ -26,12 +27,19 @@ func TestCore(t *testing.T) *Core {
noopBackends["http"] = func(*logical.BackendConfig) (logical.Backend, error) {
return new(rawHTTP), nil
}
logicalBackends := make(map[string]logical.Factory)
for backendName, backendFactory := range noopBackends {
logicalBackends[backendName] = backendFactory
}
for backendName, backendFactory := range testLogicalBackends {
logicalBackends[backendName] = backendFactory
}
physicalBackend := physical.NewInmem()
c, err := NewCore(&CoreConfig{
Physical: physicalBackend,
AuditBackends: noopAudits,
LogicalBackends: noopBackends,
LogicalBackends: logicalBackends,
CredentialBackends: noopBackends,
DisableMlock: true,
})
@ -83,6 +91,21 @@ func TestKeyCopy(key []byte) []byte {
return result
}
var testLogicalBackends = map[string]logical.Factory{}
// This adds a logical backend for the test core. This needs to be
// invoked before the test core is created.
func AddTestLogicalBackend(name string, factory logical.Factory) error {
if name == "" {
return fmt.Errorf("Missing backend name")
}
if factory == nil {
return fmt.Errorf("Missing backend factory function")
}
testLogicalBackends[name] = factory
return nil
}
type noopAudit struct{}
func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {