// Source: open-vault/builtin/logical/ssh/util.go (225 lines, 6.3 KiB, Go).

package ssh
import (
	"bytes"
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"net"
	"sort"
	"strings"
	"time"

	"github.com/hashicorp/vault/logical"
	log "github.com/mgutz/logxi/v1"
	"golang.org/x/crypto/ssh"
)
// Creates a new RSA key pair with the given key length. The private key will be
// of pem format and the public key will be of OpenSSH format.
2015-07-29 18:21:36 +00:00
func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) {
privateKey, err := rsa.GenerateKey(rand.Reader, keyBits)
if err != nil {
return "", "", fmt.Errorf("error generating RSA key-pair: %v", err)
}
privateKeyRsa = string(pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}))
sshPublicKey, err := ssh.NewPublicKey(privateKey.Public())
if err != nil {
return "", "", fmt.Errorf("error generating RSA key-pair: %v", err)
}
publicKeyRsa = "ssh-rsa " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal())
return
}
// Public key and the script to install the key are uploaded to remote machine.
// Public key is either added or removed from authorized_keys file using the
// script. Default script is for a Linux machine and hence the path of the
// authorized_keys file is hard coded to resemble Linux.
//
// The last param 'install' if false, uninstalls the key.
func (b *backend) installPublicKeyInTarget(adminUser, username, ip string, port int, hostkey, dynamicPublicKey, installScript string, install bool) error {
// Transfer the newly generated public key to remote host under a random
// file name. This is to avoid name collisions from other requests.
_, publicKeyFileName, err := b.GenerateSaltedOTP()
if err != nil {
return err
}
2016-08-19 20:45:17 +00:00
comm, err := createSSHComm(b.Logger(), adminUser, ip, port, hostkey)
2016-01-19 06:59:08 +00:00
if err != nil {
return err
}
defer comm.Close()
err = comm.Upload(publicKeyFileName, bytes.NewBufferString(dynamicPublicKey), nil)
if err != nil {
return fmt.Errorf("error uploading public key: %v", err)
}
// Transfer the script required to install or uninstall the key to the remote
// host under a random file name as well. This is to avoid name collisions
// from other requests.
scriptFileName := fmt.Sprintf("%s.sh", publicKeyFileName)
2016-01-19 06:59:08 +00:00
err = comm.Upload(scriptFileName, bytes.NewBufferString(installScript), nil)
if err != nil {
return fmt.Errorf("error uploading install script: %v", err)
}
// Create a session to run remote command that triggers the script to install
// or uninstall the key.
2016-01-19 06:59:08 +00:00
session, err := comm.NewSession()
2015-07-02 21:23:09 +00:00
if err != nil {
return fmt.Errorf("unable to create SSH Session using public keys: %v", err)
2015-07-02 21:23:09 +00:00
}
if session == nil {
return fmt.Errorf("invalid session object")
}
defer session.Close()
2015-07-10 22:18:02 +00:00
authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username)
2015-07-02 21:23:09 +00:00
var installOption string
if install {
installOption = "install"
} else {
installOption = "uninstall"
2015-07-02 21:23:09 +00:00
}
// Give execute permissions to install script, run and delete it.
chmodCmd := fmt.Sprintf("chmod +x %s", scriptFileName)
scriptCmd := fmt.Sprintf("./%s %s %s %s", scriptFileName, installOption, publicKeyFileName, authKeysFileName)
rmCmd := fmt.Sprintf("rm -f %s", scriptFileName)
targetCmd := fmt.Sprintf("%s;%s;%s", chmodCmd, scriptCmd, rmCmd)
session.Run(targetCmd)
2015-07-02 21:23:09 +00:00
return nil
}
2015-07-27 20:42:03 +00:00
// Takes an IP address and role name and checks if the IP is part
// of CIDR blocks belonging to the role.
func roleContainsIP(ctx context.Context, s logical.Storage, roleName string, ip string) (bool, error) {
if roleName == "" {
return false, fmt.Errorf("missing role name")
}
2015-07-27 20:42:03 +00:00
if ip == "" {
return false, fmt.Errorf("missing ip")
}
2015-07-27 20:42:03 +00:00
roleEntry, err := s.Get(ctx, fmt.Sprintf("roles/%s", roleName))
if err != nil {
return false, fmt.Errorf("error retrieving role %v", err)
}
if roleEntry == nil {
return false, fmt.Errorf("role %q not found", roleName)
}
2015-07-27 20:42:03 +00:00
var role sshRole
if err := roleEntry.DecodeJSON(&role); err != nil {
return false, fmt.Errorf("error decoding role %q", roleName)
}
2015-07-02 21:23:09 +00:00
if matched, err := cidrListContainsIP(ip, role.CIDRList); err != nil {
2015-07-02 21:23:09 +00:00
return false, err
} else {
return matched, nil
}
}
// cidrListContainsIP reports whether ip lies within any of the CIDR blocks in
// cidrList, a comma-separated list such as "10.0.0.0/8,192.168.1.0/24".
// Whitespace around each entry is ignored.
//
// An empty cidrList is reported as an error ("IP does not belong to role").
// Entries are checked in order. The first entry that fails to parse is
// reported as an error unless an earlier entry already matched. An ip that
// does not parse as an IP address matches no entry.
func cidrListContainsIP(ip, cidrList string) (bool, error) {
	if len(cidrList) == 0 {
		return false, fmt.Errorf("IP does not belong to role")
	}

	// The address does not change between entries, so parse it once.
	parsedIP := net.ParseIP(ip)
	for _, item := range strings.Split(cidrList, ",") {
		// Tolerate "a/8, b/16"-style lists, which ParseCIDR would reject.
		item = strings.TrimSpace(item)
		_, cidrIPNet, err := net.ParseCIDR(item)
		if err != nil {
			return false, fmt.Errorf("invalid CIDR entry %q", item)
		}
		if cidrIPNet.Contains(parsedIP) {
			return true, nil
		}
	}
	return false, nil
}
2016-08-19 20:45:17 +00:00
func createSSHComm(logger log.Logger, username, ip string, port int, hostkey string) (*comm, error) {
signer, err := ssh.ParsePrivateKey([]byte(hostkey))
2016-01-19 06:59:08 +00:00
if err != nil {
return nil, err
}
clientConfig := &ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
connfunc := func() (net.Conn, error) {
2015-08-13 15:46:55 +00:00
c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", ip, port), 15*time.Second)
if err != nil {
return nil, err
}
if tcpConn, ok := c.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(5 * time.Second)
}
return c, nil
}
2015-08-13 15:46:55 +00:00
config := &SSHCommConfig{
SSHConfig: clientConfig,
Connection: connfunc,
Pty: false,
DisableAgent: true,
2016-08-19 20:45:17 +00:00
Logger: logger,
}
2016-01-19 06:59:08 +00:00
return SSHCommNew(fmt.Sprintf("%s:%d", ip, port), config)
}
2016-12-26 14:03:27 +00:00
func parsePublicSSHKey(key string) (ssh.PublicKey, error) {
keyParts := strings.Split(key, " ")
if len(keyParts) > 1 {
// Someone has sent the 'full' public key rather than just the base64 encoded part that the ssh library wants
key = keyParts[1]
}
decodedKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
return ssh.ParsePublicKey([]byte(decodedKey))
}
// convertMapToStringValue returns a new map containing every entry of
// initial, with each value rendered as a string using fmt's default (%v)
// formatting. A nil or empty input yields an empty, non-nil map. The input
// map is not modified.
func convertMapToStringValue(initial map[string]interface{}) map[string]string {
	converted := make(map[string]string, len(initial))
	for name, raw := range initial {
		converted[name] = fmt.Sprint(raw)
	}
	return converted
}
// substQuery expands a simple template for custom format inputs. Every
// occurrence of "{{key}}" in tpl is replaced by data[key]. Placeholders whose
// key is not present in data are left untouched.
//
// Substitution is a single left-to-right pass over tpl, so text inserted
// from data is never re-expanded. Replacing key by key would instead make
// the result depend on Go's randomized map iteration order whenever a value
// contains the text of another placeholder.
func substQuery(tpl string, data map[string]string) string {
	if len(data) == 0 {
		return tpl
	}

	// Sort keys so that, if one placeholder is a prefix of another, the
	// replacer's argument-order tie-break is deterministic.
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	pairs := make([]string, 0, 2*len(keys))
	for _, k := range keys {
		pairs = append(pairs, "{{"+k+"}}", data[k])
	}
	return strings.NewReplacer(pairs...).Replace(tpl)
}