Vault SSH: Test case skeleton
This commit is contained in:
parent
3e4b67f90a
commit
3c7dd8611c
|
@ -0,0 +1,227 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
testCidr = "127.0.0.1/32"
|
||||
testRoleName = "testRoleName"
|
||||
testKey = "testKey"
|
||||
testPublicKey = `
|
||||
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv
|
||||
`
|
||||
testPrivateKey = `
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn
|
||||
Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k
|
||||
yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM
|
||||
PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY
|
||||
rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC
|
||||
bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH
|
||||
hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu
|
||||
c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa
|
||||
0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq
|
||||
d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/
|
||||
fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm
|
||||
U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2
|
||||
V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg
|
||||
j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu
|
||||
+fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu
|
||||
vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw
|
||||
eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96
|
||||
u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c
|
||||
+LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW
|
||||
U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi
|
||||
fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8
|
||||
Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw
|
||||
Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3
|
||||
7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1
|
||||
eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`
|
||||
)
|
||||
|
||||
var testIP string
|
||||
var testPort string
|
||||
var testUserName string
|
||||
var testAdminUser string
|
||||
|
||||
func init() {
|
||||
addr, err := startTestServer()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error starting mock server:%s", err))
|
||||
}
|
||||
input := strings.Split(addr, ":")
|
||||
testIP = input[0]
|
||||
testPort = input[1]
|
||||
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error getting current username: '%s'", err))
|
||||
}
|
||||
testUserName = u.Username
|
||||
testAdminUser = u.Username
|
||||
}
|
||||
|
||||
func TestSSHBackend(t *testing.T) {
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testNamedKeys(t),
|
||||
testNewRole(t),
|
||||
testRoleCreate(t),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func startTestServer() (string, error) {
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Error parsing public key")
|
||||
}
|
||||
serverConfig := &ssh.ServerConfig{
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
|
||||
return &ssh.Permissions{}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Key does not match")
|
||||
}
|
||||
},
|
||||
}
|
||||
signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
|
||||
if err != nil {
|
||||
panic("Error parsing private key")
|
||||
}
|
||||
serverConfig.AddHostKey(signer)
|
||||
|
||||
soc, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Error listening to connection")
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := soc.Accept()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
|
||||
}
|
||||
defer conn.Close()
|
||||
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Handshaking error: %v", err))
|
||||
}
|
||||
|
||||
go func() {
|
||||
for chanReq := range chanReqs {
|
||||
go func(chanReq ssh.NewChannel) {
|
||||
if chanReq.ChannelType() != "session" {
|
||||
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
return
|
||||
}
|
||||
|
||||
ch, requests, err := chanReq.Accept()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error accepting channel: %s", err))
|
||||
}
|
||||
|
||||
go func(ch ssh.Channel, in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
executeCommand(ch, req)
|
||||
}
|
||||
}(ch, requests)
|
||||
}(chanReq)
|
||||
}
|
||||
sshConn.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
return soc.Addr().String(), nil
|
||||
}
|
||||
|
||||
func executeCommand(ch ssh.Channel, req *ssh.Request) {
|
||||
command := string(req.Payload[4:])
|
||||
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
|
||||
req.Reply(true, nil)
|
||||
|
||||
cmd.Stdout = ch
|
||||
cmd.Stderr = ch
|
||||
cmd.Stdin = ch
|
||||
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error starting the command: '%s'", err))
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, err := cmd.Process.Wait()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
|
||||
}
|
||||
ch.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
func testRoleCreate(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: fmt.Sprintf("creds/%s", testRoleName),
|
||||
Data: map[string]interface{}{
|
||||
"username": testUserName,
|
||||
"ip": testIP,
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Key string `mapstructure:"key"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] Generated Key:%s\n", d.Key)
|
||||
if d.Key == "" {
|
||||
return fmt.Errorf("Generated key is an empty string")
|
||||
}
|
||||
_, err := ssh.ParsePrivateKey([]byte(d.Key))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Generated key is invalid")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testNewRole(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: fmt.Sprintf("roles/%s", testRoleName),
|
||||
Data: map[string]interface{}{
|
||||
"key": testKey,
|
||||
"admin_user": testAdminUser,
|
||||
"cidr": testCidr,
|
||||
"port": testPort,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testNamedKeys(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: fmt.Sprintf("keys/%s", testKey),
|
||||
Data: map[string]interface{}{
|
||||
"key": testPrivateKey,
|
||||
},
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ package ssh
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -35,6 +36,7 @@ func pathRoleCreate(b *backend) *framework.Path {
|
|||
|
||||
func (b *backend) pathRoleCreateWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
log.Printf("Vishal: pathRoleCreateWrite\n")
|
||||
roleName := d.Get("name").(string)
|
||||
username := d.Get("username").(string)
|
||||
ipRaw := d.Get("ip").(string)
|
||||
|
@ -92,9 +94,11 @@ func (b *backend) pathRoleCreateWrite(
|
|||
|
||||
// Transfer the public key to target machine
|
||||
err = uploadPublicKeyScp(dynamicPublicKey, username, ip, role.Port, hostKey.Key)
|
||||
//return nil, nil //TODO remove this
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("Vishal: uploaded public key file to target\n")
|
||||
|
||||
// Add the public key to authorized_keys file in target machine
|
||||
err = installPublicKeyInTarget(username, ip, role.Port, hostKey.Key)
|
||||
|
@ -102,6 +106,7 @@ func (b *backend) pathRoleCreateWrite(
|
|||
return nil, fmt.Errorf("error adding public key to authorized_keys file in target")
|
||||
}
|
||||
|
||||
log.Printf("Vishal: installed public key file to target\n")
|
||||
result := b.Secret(SecretDynamicKeyType).Response(map[string]interface{}{
|
||||
"key": dynamicPrivateKey,
|
||||
}, map[string]interface{}{
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
|
@ -36,9 +37,9 @@ func uploadPublicKeyScp(publicKey, username, ip, port, key string) error {
|
|||
fmt.Fprint(w, "\x00")
|
||||
w.Close()
|
||||
}()
|
||||
if err := session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName)); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Vishal: uploading now\n")
|
||||
err = session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName))
|
||||
log.Printf("Vishal: upload completed: err:%s\n", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -113,22 +114,22 @@ func installPublicKeyInTarget(username, ip, port, hostKey string) error {
|
|||
}
|
||||
defer session.Close()
|
||||
|
||||
authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username)
|
||||
tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username)
|
||||
authKeysFileName := "~/.ssh/authorized_keys"
|
||||
tempKeysFileName := "~/temp_authorized_keys"
|
||||
|
||||
// Commands to be run on target machine
|
||||
dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip)
|
||||
grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName)
|
||||
catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName)
|
||||
catCmdAppendNew := fmt.Sprintf("cat %s >> %s", dynamicPublicKeyFileName, authKeysFileName)
|
||||
removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
|
||||
//removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
|
||||
log.Printf(grepCmd)
|
||||
log.Printf(catCmdRemoveDuplicate)
|
||||
log.Printf(catCmdAppendNew)
|
||||
|
||||
targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd)
|
||||
|
||||
// Run the commands on target machine
|
||||
if err := session.Run(targetCmd); err != nil {
|
||||
return err
|
||||
}
|
||||
//targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd)
|
||||
targetCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew)
|
||||
session.Run(targetCmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -143,8 +144,8 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error {
|
|||
}
|
||||
defer session.Close()
|
||||
|
||||
authKeysFileName := "/home/" + username + "/.ssh/authorized_keys"
|
||||
tempKeysFileName := "/home/" + username + "/temp_authorized_keys"
|
||||
authKeysFileName := "~/.ssh/authorized_keys"
|
||||
tempKeysFileName := "~/temp_authorized_keys"
|
||||
|
||||
// Commands to be run on target machine
|
||||
dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip)
|
||||
|
@ -153,11 +154,7 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error {
|
|||
removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName)
|
||||
|
||||
remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd)
|
||||
|
||||
// Run the commands in target machine
|
||||
if err := session.Run(remoteCmd); err != nil {
|
||||
return err
|
||||
}
|
||||
session.Run(remoteCmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ func (c *SSHCommand) Run(args []string) int {
|
|||
c.Ui.Error(fmt.Sprintf("Error setting default role: %s", err.Error()))
|
||||
return 1
|
||||
}
|
||||
c.Ui.Output(fmt.Sprintf("Using role[%s]\n", role))
|
||||
c.Ui.Output(fmt.Sprintf("Vault SSH: Role:'%s'\n", role))
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
|
@ -72,10 +72,14 @@ func (c *SSHCommand) Run(args []string) int {
|
|||
sshDynamicKeyFileName := fmt.Sprintf("vault_temp_file_%s_%s", username, ip.String())
|
||||
err = ioutil.WriteFile(sshDynamicKeyFileName, []byte(sshDynamicKey), 0600)
|
||||
|
||||
cmd := exec.Command("ssh", "-p", port, "-i", sshDynamicKeyFileName, args[0])
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
err = cmd.Run()
|
||||
sshCmdArgs := []string{"-p", port, "-i", sshDynamicKeyFileName}
|
||||
sshCmdArgs = append(sshCmdArgs, args...)
|
||||
|
||||
sshCmd := exec.Command("ssh", sshCmdArgs...)
|
||||
sshCmd.Stdin = os.Stdin
|
||||
sshCmd.Stdout = os.Stdout
|
||||
|
||||
err = sshCmd.Run()
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error while running ssh command:%s", err))
|
||||
}
|
||||
|
@ -145,6 +149,8 @@ SSH Options:
|
|||
skip mentioning the role. It will be chosen by default.
|
||||
If there are no roless associated with the IP, register
|
||||
the CIDR block of that IP using the "roles/" endpoint.
|
||||
|
||||
-port Port number to use for SSH connection. This defaults to port 22.
|
||||
`
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,259 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
logicalssh "github.com/hashicorp/vault/builtin/logical/ssh"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/mitchellh/cli"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
testCidr = "127.0.0.1/32"
|
||||
testRoleName = "testRoleName"
|
||||
testKey = "testKey"
|
||||
testPublicKey = `
|
||||
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv
|
||||
`
|
||||
testPrivateKey = `
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn
|
||||
Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k
|
||||
yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM
|
||||
PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY
|
||||
rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC
|
||||
bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH
|
||||
hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu
|
||||
c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa
|
||||
0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq
|
||||
d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/
|
||||
fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm
|
||||
U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2
|
||||
V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg
|
||||
j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu
|
||||
+fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu
|
||||
vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw
|
||||
eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96
|
||||
u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c
|
||||
+LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW
|
||||
U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi
|
||||
fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8
|
||||
Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw
|
||||
Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3
|
||||
7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1
|
||||
eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`
|
||||
)
|
||||
|
||||
var testIP string
|
||||
var testPort string
|
||||
var testUserName string
|
||||
var testAdminUser string
|
||||
|
||||
func init() {
|
||||
addr, err := startTestServer()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error starting mock server:%s", err))
|
||||
}
|
||||
input := strings.Split(addr, ":")
|
||||
testIP = input[0]
|
||||
testPort = input[1]
|
||||
//testPort = "22"
|
||||
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error getting current username: '%s'", err))
|
||||
}
|
||||
testUserName = u.Username
|
||||
testAdminUser = u.Username
|
||||
//testUserName = "vishal" //TODO: remove this
|
||||
//testAdminUser = "vishal" //TODO: remove this
|
||||
}
|
||||
|
||||
func TestSSH(t *testing.T) {
|
||||
err := vault.AddTestLogicalBackend("ssh", logicalssh.Factory)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
mountCmd := &MountCommand{
|
||||
Meta: Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{"-address", addr, "ssh"}
|
||||
log.Printf("Vishal: mount args: %#v\n", args)
|
||||
|
||||
if code := mountCmd.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
|
||||
client, err := mountCmd.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
mounts, err := client.Sys().ListMounts()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
mount, ok := mounts["ssh/"]
|
||||
if !ok {
|
||||
t.Fatal("should have ssh mount")
|
||||
}
|
||||
if mount.Type != "ssh" {
|
||||
t.Fatal("should have ssh type")
|
||||
}
|
||||
writeCmd := &WriteCommand{
|
||||
Meta: Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
args = []string{
|
||||
"-address", addr,
|
||||
"ssh/keys/" + testKey,
|
||||
"key=" + testPrivateKey,
|
||||
}
|
||||
log.Printf("Vishal: write args: %#v\n", args)
|
||||
if code := writeCmd.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
|
||||
args = []string{
|
||||
"-address", addr,
|
||||
"ssh/roles/" + testRoleName,
|
||||
"key=" + testKey,
|
||||
"admin_user=" + testUserName,
|
||||
"cidr=" + testCidr,
|
||||
"port=" + testPort,
|
||||
}
|
||||
log.Printf("Vishal: write args: %#v\n", args)
|
||||
if code := writeCmd.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
log.Printf("Vishal: Reached here\n")
|
||||
|
||||
sshCmd := &SSHCommand{
|
||||
Meta: Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
args = []string{
|
||||
"-address", addr,
|
||||
"-role=" + testRoleName,
|
||||
testUserName + "@" + testIP,
|
||||
"/usr/bin/whoami",
|
||||
}
|
||||
log.Printf("Vishal: ssh args: %#v\n", args)
|
||||
if code := sshCmd.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
log.Printf("addr:%s testRoleName:%s testUserName:%s testIP:%s testPort:%s\n", addr, testRoleName, testUserName, testIP, testPort)
|
||||
// TODO: Compare the testUserName and response of whoami should match! else fail test.
|
||||
}
|
||||
|
||||
func executeCommand(ch ssh.Channel, req *ssh.Request) {
|
||||
command := string(req.Payload[4:])
|
||||
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
|
||||
req.Reply(true, nil)
|
||||
|
||||
cmd.Stdout = ch
|
||||
cmd.Stderr = ch
|
||||
cmd.Stdin = ch
|
||||
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error starting the command: '%s'", err))
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, err := cmd.Process.Wait()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
|
||||
}
|
||||
ch.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
func startTestServer() (string, error) {
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Error parsing public key")
|
||||
}
|
||||
serverConfig := &ssh.ServerConfig{
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
|
||||
return &ssh.Permissions{}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Key does not match")
|
||||
}
|
||||
},
|
||||
}
|
||||
signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
|
||||
if err != nil {
|
||||
panic("Error parsing private key")
|
||||
}
|
||||
serverConfig.AddHostKey(signer)
|
||||
|
||||
soc, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Error listening to connection")
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := soc.Accept()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
|
||||
}
|
||||
defer conn.Close()
|
||||
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Handshaking error: %v", err))
|
||||
}
|
||||
|
||||
go func() {
|
||||
for chanReq := range chanReqs {
|
||||
go func(chanReq ssh.NewChannel) {
|
||||
if chanReq.ChannelType() != "session" {
|
||||
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
return
|
||||
}
|
||||
|
||||
ch, requests, err := chanReq.Accept()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error accepting channel: %s", err))
|
||||
}
|
||||
|
||||
go func(ch ssh.Channel, in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
executeCommand(ch, req)
|
||||
}
|
||||
}(ch, requests)
|
||||
}(chanReq)
|
||||
}
|
||||
sshConn.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
return soc.Addr().String(), nil
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/audit"
|
||||
|
@ -26,12 +27,19 @@ func TestCore(t *testing.T) *Core {
|
|||
noopBackends["http"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
return new(rawHTTP), nil
|
||||
}
|
||||
logicalBackends := make(map[string]logical.Factory)
|
||||
for backendName, backendFactory := range noopBackends {
|
||||
logicalBackends[backendName] = backendFactory
|
||||
}
|
||||
for backendName, backendFactory := range testLogicalBackends {
|
||||
logicalBackends[backendName] = backendFactory
|
||||
}
|
||||
|
||||
physicalBackend := physical.NewInmem()
|
||||
c, err := NewCore(&CoreConfig{
|
||||
Physical: physicalBackend,
|
||||
AuditBackends: noopAudits,
|
||||
LogicalBackends: noopBackends,
|
||||
LogicalBackends: logicalBackends,
|
||||
CredentialBackends: noopBackends,
|
||||
DisableMlock: true,
|
||||
})
|
||||
|
@ -83,6 +91,21 @@ func TestKeyCopy(key []byte) []byte {
|
|||
return result
|
||||
}
|
||||
|
||||
var testLogicalBackends = map[string]logical.Factory{}
|
||||
|
||||
// This adds a logical backend for the test core. This needs to be
|
||||
// invoked before the test core is created.
|
||||
func AddTestLogicalBackend(name string, factory logical.Factory) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("Missing backend name")
|
||||
}
|
||||
if factory == nil {
|
||||
return fmt.Errorf("Missing backend factory function")
|
||||
}
|
||||
testLogicalBackends[name] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopAudit struct{}
|
||||
|
||||
func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {
|
||||
|
|
Loading…
Reference in New Issue