open-vault/builtin/logical/ssh/util.go

198 lines
5.5 KiB
Go

package ssh
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"net"
"strings"
"time"
"github.com/hashicorp/vault/logical"
"golang.org/x/crypto/ssh"
)
// createSSHPublicKeysSession dials ipAddr:port over TCP and opens an SSH
// session authenticated as username using public key authentication with the
// PEM-encoded private key in hostKey. The returned session can be used to run
// commands on the target machine; the caller is responsible for closing it.
//
// It returns an error if any of username, ipAddr or hostKey is empty, if the
// key cannot be parsed, or if dialing / opening the session fails.
//
// NOTE(review): the ClientConfig sets no HostKeyCallback, so the server's host
// key is not verified here. Newer golang.org/x/crypto/ssh releases reject a
// nil HostKeyCallback — confirm behavior against the vendored version.
//
// NOTE(review): closing the returned *ssh.Session does not close the
// underlying *ssh.Client; the connection stays open until the peer drops it.
// Fixing that fully requires returning the client to callers.
func createSSHPublicKeysSession(username, ipAddr string, port int, hostKey string) (*ssh.Session, error) {
	if username == "" {
		return nil, fmt.Errorf("missing username")
	}
	if ipAddr == "" {
		return nil, fmt.Errorf("missing ip address")
	}
	if hostKey == "" {
		return nil, fmt.Errorf("missing host key")
	}

	signer, err := ssh.ParsePrivateKey([]byte(hostKey))
	if err != nil {
		return nil, fmt.Errorf("parsing Private Key failed: %s", err)
	}

	config := &ssh.ClientConfig{
		User: username,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
	}

	// JoinHostPort brackets IPv6 literals; "%s:%d" would yield an ambiguous
	// address such as "::1:22".
	client, err := ssh.Dial("tcp", net.JoinHostPort(ipAddr, fmt.Sprint(port)), config)
	if err != nil {
		return nil, err
	}

	session, err := client.NewSession()
	if err != nil {
		// Don't leak the TCP connection when the session cannot be opened.
		client.Close()
		return nil, err
	}
	return session, nil
}
// generateRSAKeys creates a fresh RSA key pair whose modulus is keyBits bits.
//
// The private key is returned PEM-encoded as a PKCS#1 "RSA PRIVATE KEY" block.
// The public key is returned in OpenSSH form: "ssh-rsa " followed by the
// base64 of the SSH wire encoding. On error both strings are empty.
func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) {
	key, genErr := rsa.GenerateKey(rand.Reader, keyBits)
	if genErr != nil {
		return "", "", fmt.Errorf("error generating RSA key-pair: %s", genErr)
	}

	pub, pubErr := ssh.NewPublicKey(key.Public())
	if pubErr != nil {
		return "", "", fmt.Errorf("error generating RSA key-pair: %s", pubErr)
	}

	privBlock := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	privatePEM := pem.EncodeToMemory(privBlock)
	encodedPub := base64.StdEncoding.EncodeToString(pub.Marshal())

	return "ssh-rsa " + encodedPub, string(privatePEM), nil
}
// installPublicKeyInTarget installs or uninstalls a dynamic public key on the
// remote host ip:port. It connects as adminUser using the private key in
// hostkey and runs the script "<publicKeyFileName>.sh" from the remote
// working directory. The script must already exist there; it is not uploaded
// here. Its arguments are "install" or "uninstall", the public key file name,
// and the authorized_keys path of username. The script is deleted afterwards.
//
// The remote host is assumed to be Linux, so the authorized_keys path is hard
// coded as /home/<username>/.ssh/authorized_keys. Installing appends the key
// to that file and uninstalling removes it; both are done by the script.
//
// If install is false the key is uninstalled. An error is returned if the SSH
// session cannot be created or if running the command fails.
func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip string, port int, hostkey string, install bool) error {
	session, err := createSSHPublicKeysSession(adminUser, ip, port, hostkey)
	if err != nil {
		return fmt.Errorf("unable to create SSH Session using public keys: %s", err)
	}
	if session == nil {
		return fmt.Errorf("invalid session object")
	}
	defer session.Close()

	authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username)
	scriptFileName := fmt.Sprintf("%s.sh", publicKeyFileName)

	installOption := "uninstall"
	if install {
		installOption = "install"
	}

	// Single-quote every interpolated value. Otherwise a username or file name
	// containing shell metacharacters could inject extra commands into a shell
	// running as adminUser on the target. An embedded ' is encoded as '\''.
	quote := func(s string) string {
		return "'" + strings.Replace(s, "'", `'\''`, -1) + "'"
	}

	// Give execute permissions to the install script, run it, and delete it.
	chmodCmd := fmt.Sprintf("chmod +x %s", quote(scriptFileName))
	scriptCmd := fmt.Sprintf("./%s %s %s %s", quote(scriptFileName), installOption, quote(publicKeyFileName), quote(authKeysFileName))
	rmCmd := fmt.Sprintf("rm -f %s", quote(scriptFileName))
	targetCmd := fmt.Sprintf("%s;%s;%s", chmodCmd, scriptCmd, rmCmd)

	// NOTE(review): with ';' chaining, the remote exit status is that of the
	// final rm, so a failing install script is not reflected here. Connection
	// and session errors are.
	if err := session.Run(targetCmd); err != nil {
		return fmt.Errorf("error running install script on target: %s", err)
	}
	return nil
}
// roleContainsIP reports whether ip falls within one of the CIDR blocks listed
// in the CIDRList of the role stored under "roles/<roleName>".
//
// An error is returned in any of these cases:
//   - roleName or ip is empty;
//   - the role cannot be read, does not exist, or cannot be decoded;
//   - the role's CIDR list contains an invalid entry.
func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error) {
	if roleName == "" {
		return false, fmt.Errorf("missing role name")
	}
	if ip == "" {
		return false, fmt.Errorf("missing ip")
	}

	roleEntry, err := s.Get(fmt.Sprintf("roles/%s", roleName))
	if err != nil {
		return false, fmt.Errorf("error retrieving role '%s': %s", roleName, err)
	}
	if roleEntry == nil {
		return false, fmt.Errorf("role '%s' not found", roleName)
	}

	var role sshRole
	if err := roleEntry.DecodeJSON(&role); err != nil {
		return false, fmt.Errorf("error decoding role '%s': %s", roleName, err)
	}

	return cidrContainsIP(ip, role.CIDRList)
}
// cidrContainsIP reports whether ip lies inside any of the CIDR blocks in the
// comma-separated cidrList (e.g. "10.0.0.0/8, 192.168.0.0/16"). Whitespace
// around each entry is ignored. An ip that cannot be parsed matches nothing.
// Entries are checked in order. The first entry containing ip returns true.
// The first invalid entry returns an error, and an empty entry counts as
// invalid, so an empty cidrList is an error.
func cidrContainsIP(ip, cidrList string) (bool, error) {
	// Parse once, outside the loop. net.ParseIP returns nil for malformed
	// input, and (*net.IPNet).Contains(nil) is false.
	parsedIP := net.ParseIP(ip)
	for _, item := range strings.Split(cidrList, ",") {
		item = strings.TrimSpace(item)
		_, cidrIPNet, err := net.ParseCIDR(item)
		if err != nil {
			return false, fmt.Errorf("invalid CIDR entry '%s'", item)
		}
		if cidrIPNet.Contains(parsedIP) {
			return true, nil
		}
	}
	return false, nil
}
// scpUpload connects to ip:port as username and uploads fileContent to
// fileName on the target through the SSH communicator (SSHCommNew / Upload).
// It authenticates with the PEM-encoded private key in hostkey.
//
// The TCP dial times out after 15 seconds. TCP keep-alives are enabled on the
// connection with a 5 second period. An error is returned if the key cannot be
// parsed, the connection fails, or the upload fails.
func scpUpload(username, ip string, port int, hostkey, fileName, fileContent string) error {
	signer, err := ssh.ParsePrivateKey([]byte(hostkey))
	if err != nil {
		// Without this check a nil signer would be handed to PublicKeys.
		return fmt.Errorf("parsing Private Key failed: %s", err)
	}

	clientConfig := &ssh.ClientConfig{
		User: username,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
	}

	// JoinHostPort brackets IPv6 literals; "%s:%d" would not.
	addr := net.JoinHostPort(ip, fmt.Sprint(port))

	connfunc := func() (net.Conn, error) {
		c, err := net.DialTimeout("tcp", addr, 15*time.Second)
		if err != nil {
			return nil, err
		}
		if tcpConn, ok := c.(*net.TCPConn); ok {
			tcpConn.SetKeepAlive(true)
			tcpConn.SetKeepAlivePeriod(5 * time.Second)
		}
		return c, nil
	}

	config := &SSHCommConfig{
		SSHConfig:    clientConfig,
		Connection:   connfunc,
		Pty:          false,
		DisableAgent: true,
	}

	comm, err := SSHCommNew(addr, config)
	if err != nil {
		return fmt.Errorf("error connecting to target: %s", err)
	}

	if err := comm.Upload(fileName, bytes.NewBufferString(fileContent), nil); err != nil {
		return fmt.Errorf("error uploading file to target: %s", err)
	}
	return nil
}