Vault SSH: cidr to cidr_list

This commit is contained in:
vishalnayak 2015-08-13 08:46:55 -07:00
parent 7d3025fd6e
commit 8e946f27cc
6 changed files with 383 additions and 52 deletions

View file

@ -18,7 +18,7 @@ import (
const (
testOTPKeyType = "otp"
testDynamicKeyType = "dynamic"
testCidr = "127.0.0.1/32"
testCIDRList = "127.0.0.1/32"
testDynamicRoleName = "testDynamicRoleName"
testOTPRoleName = "testOTPRoleName"
testKeyName = "testKeyName"
@ -55,7 +55,7 @@ oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
var testIP string
var testOTP string
var testPort string
var testPort int
var testUserName string
var testAdminUser string
var testInstallScript string
@ -92,13 +92,13 @@ func TestSSHBackend_Lookup(t *testing.T) {
otpData := map[string]interface{}{
"key_type": testOTPKeyType,
"default_user": testUserName,
"cidr": testCidr,
"cidr_list": testCIDRList,
}
dynamicData := map[string]interface{}{
"key_type": testDynamicKeyType,
"key": testKeyName,
"admin_user": testAdminUser,
"cidr": testCidr,
"cidr_list": testCIDRList,
"install_script": testInstallScript,
}
logicaltest.Test(t, logicaltest.TestCase{
@ -133,7 +133,7 @@ func TestSSHBackend_OTPRoleCrud(t *testing.T) {
data := map[string]interface{}{
"key_type": testOTPKeyType,
"default_user": testUserName,
"cidr": testCidr,
"cidr_list": testCIDRList,
}
logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory,
@ -151,7 +151,7 @@ func TestSSHBackend_DynamicRoleCrud(t *testing.T) {
"key_type": testDynamicKeyType,
"key": testKeyName,
"admin_user": testAdminUser,
"cidr": testCidr,
"cidr_list": testCIDRList,
"install_script": testInstallScript,
}
logicaltest.Test(t, logicaltest.TestCase{
@ -182,7 +182,7 @@ func TestSSHBackend_OTPCreate(t *testing.T) {
data := map[string]interface{}{
"key_type": testOTPKeyType,
"default_user": testUserName,
"cidr": testCidr,
"cidr_list": testCIDRList,
}
logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory,
@ -303,11 +303,11 @@ func testRoleRead(t *testing.T, name string, data map[string]interface{}) logica
return fmt.Errorf("error decoding response:%s", err)
}
if name == testOTPRoleName {
if d.KeyType != data["key_type"] || d.DefaultUser != data["default_user"] || d.CIDR != data["cidr"] {
if d.KeyType != data["key_type"] || d.DefaultUser != data["default_user"] || d.CIDRList != data["cidr_list"] {
return fmt.Errorf("data mismatch. bad: %#v", resp)
}
} else {
if d.AdminUser != data["admin_user"] || d.CIDR != data["cidr"] || d.KeyName != data["key"] || d.KeyType != data["key_type"] {
if d.AdminUser != data["admin_user"] || d.CIDRList != data["cidr_list"] || d.KeyName != data["key"] || d.KeyType != data["key_type"] {
return fmt.Errorf("data mismatch. bad: %#v", resp)
}
}
@ -331,7 +331,7 @@ func testNewDynamicKeyRole(t *testing.T) logicaltest.TestStep {
"key_type": "dynamic",
"key": testKeyName,
"admin_user": testAdminUser,
"cidr": testCidr,
"cidr_list": testCIDRList,
"port": testPort,
"install_script": testInstallScript,
},

View file

@ -0,0 +1,338 @@
package ssh
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"path/filepath"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
type comm struct {
client *ssh.Client
config *SSHCommConfig
conn net.Conn
address string
}
// SSHCommConfig is the structure used to configure the SSH communicator.
type SSHCommConfig struct {
// The configuration of the Go SSH connection
SSHConfig *ssh.ClientConfig
// Connection returns a new connection. The current connection
// in use will be closed as part of the Close method, or in the
// case an error occurs.
Connection func() (net.Conn, error)
// Pty, if true, will request a pty from the remote end.
Pty bool
// DisableAgent, if true, will not forward the SSH agent.
DisableAgent bool
}
// Creates a new communicator implementation over SSH. This takes
// an already existing TCP connection and SSH configuration.
func SSHCommNew(address string, config *SSHCommConfig) (result *comm, err error) {
// Establish an initial connection and connect
result = &comm{
config: config,
address: address,
}
if err = result.reconnect(); err != nil {
result = nil
return
}
return
}
func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
// The target directory and file for talking the SCP protocol
target_dir := filepath.Dir(path)
target_file := filepath.Base(path)
// On windows, filepath.Dir uses backslash seperators (ie. "\tmp").
// This does not work when the target host is unix. Switch to forward slash
// which works for unix and windows
target_dir = filepath.ToSlash(target_dir)
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
return scpUploadFile(target_file, input, w, stdoutR, fi)
}
return c.scpSession("scp -vt "+target_dir, scpFunc)
}
func (c *comm) newSession() (session *ssh.Session, err error) {
if c.client == nil {
err = errors.New("client not available")
} else {
session, err = c.client.NewSession()
}
if err != nil {
log.Printf("ssh session open error: '%s', attempting reconnect", err)
if err := c.reconnect(); err != nil {
return nil, err
}
return c.client.NewSession()
}
return session, nil
}
func (c *comm) reconnect() (err error) {
if c.conn != nil {
c.conn.Close()
}
// Set the conn and client to nil since we'll recreate it
c.conn = nil
c.client = nil
c.conn, err = c.config.Connection()
if err != nil {
// Explicitly set this to the REAL nil. Connection() can return
// a nil implementation of net.Conn which will make the
// "if c.conn == nil" check fail above. Read here for more information
// on this psychotic language feature:
//
// http://golang.org/doc/faq#nil_error
c.conn = nil
log.Printf("reconnection error: %s", err)
return
}
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
if err != nil {
log.Printf("handshake error: %s", err)
}
if sshConn != nil {
c.client = ssh.NewClient(sshConn, sshChan, req)
}
c.connectToAgent()
return
}
func (c *comm) connectToAgent() {
if c.client == nil {
return
}
if c.config.DisableAgent {
return
}
// open connection to the local agent
socketLocation := os.Getenv("SSH_AUTH_SOCK")
if socketLocation == "" {
return
}
agentConn, err := net.Dial("unix", socketLocation)
if err != nil {
log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
return
}
// create agent and add in auth
forwardingAgent := agent.NewClient(agentConn)
if forwardingAgent == nil {
log.Printf("[ERROR] Could not create agent client")
agentConn.Close()
return
}
// add callback for forwarding agent to SSH config
// XXX - might want to handle reconnects appending multiple callbacks
auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth)
agent.ForwardToAgent(c.client, forwardingAgent)
// Setup a session to request agent forwarding
session, err := c.newSession()
if err != nil {
return
}
defer session.Close()
err = agent.RequestAgentForwarding(session)
if err != nil {
log.Printf("[ERROR] RequestAgentForwarding: %#v", err)
return
}
return
}
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
session, err := c.newSession()
if err != nil {
return err
}
defer session.Close()
// Get a pipe to stdin so that we can send data down
stdinW, err := session.StdinPipe()
if err != nil {
return err
}
// We only want to close once, so we nil w after we close it,
// and only close in the defer if it hasn't been closed already.
defer func() {
if stdinW != nil {
stdinW.Close()
}
}()
// Get a pipe to stdout so that we can get responses back
stdoutPipe, err := session.StdoutPipe()
if err != nil {
return err
}
stdoutR := bufio.NewReader(stdoutPipe)
// Set stderr to a bytes buffer
stderr := new(bytes.Buffer)
session.Stderr = stderr
// Start the sink mode on the other side
if err := session.Start(scpCommand); err != nil {
return err
}
// Call our callback that executes in the context of SCP. We ignore
// EOF errors if they occur because it usually means that SCP prematurely
// ended on the other side.
if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
return err
}
// Close the stdin, which sends an EOF, and then set w to nil so that
// our defer func doesn't close it again since that is unsafe with
// the Go SSH package.
stdinW.Close()
stdinW = nil
// Wait for the SCP connection to close, meaning it has consumed all
// our data and has completed. Or has errored.
err = session.Wait()
if err != nil {
if exitErr, ok := err.(*ssh.ExitError); ok {
// Otherwise, we have an ExitErorr, meaning we can just read
// the exit status
log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
// If we exited with status 127, it means SCP isn't available.
// Return a more descriptive error for that.
if exitErr.ExitStatus() == 127 {
return errors.New(
"SCP failed to start. This usually means that SCP is not\n" +
"properly installed on the remote system.")
}
}
return err
}
log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
return nil
}
// checkSCPStatus checks that a prior command sent to SCP completed
// successfully. If it did not complete successfully, an error will
// be returned.
func checkSCPStatus(r *bufio.Reader) error {
code, err := r.ReadByte()
if err != nil {
return err
}
if code != 0 {
// Treat any non-zero (really 1 and 2) as fatal errors
message, _, err := r.ReadLine()
if err != nil {
return fmt.Errorf("Error reading error message: %s", err)
}
return errors.New(string(message))
}
return nil
}
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
var mode os.FileMode
var size int64
if fi != nil && (*fi).Mode().IsRegular() {
mode = (*fi).Mode().Perm()
size = (*fi).Size()
} else {
// Create a temporary file where we can copy the contents of the src
// so that we can determine the length, since SCP is length-prefixed.
tf, err := ioutil.TempFile("", "vault-ssh-upload")
if err != nil {
return fmt.Errorf("Error creating temporary file for upload: %s", err)
}
defer os.Remove(tf.Name())
defer tf.Close()
mode = 0644
log.Println("Copying input data into temporary file so we can read the length")
if _, err := io.Copy(tf, src); err != nil {
return err
}
// Sync the file so that the contents are definitely on disk, then
// read the length of it.
if err := tf.Sync(); err != nil {
return fmt.Errorf("Error creating temporary file for upload: %s", err)
}
// Seek the file to the beginning so we can re-read all of it
if _, err := tf.Seek(0, 0); err != nil {
return fmt.Errorf("Error creating temporary file for upload: %s", err)
}
tfi, err := tf.Stat()
if err != nil {
return fmt.Errorf("Error creating temporary file for upload: %s", err)
}
size = tfi.Size()
src = tf
}
// Start the protocol
perms := fmt.Sprintf("C%04o", mode)
fmt.Fprintln(w, perms, size, dst)
if err := checkSCPStatus(r); err != nil {
return err
}
if _, err := io.CopyN(w, src, size); err != nil {
return err
}
fmt.Fprint(w, "\x00")
if err := checkSCPStatus(r); err != nil {
return err
}
return nil
}

View file

@ -11,17 +11,11 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
const defaultSSHLeaseDuration = 5 * time.Minute
type sshOTP struct {
Username string `json:"username"`
IP string `json:"ip"`
}
type sshCIDR struct {
CIDR []string
}
func pathCredsCreate(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/(?P<role>[-\\w]+)",
@ -87,7 +81,7 @@ func (b *backend) pathCredsCreateWrite(
return logical.ErrorResponse(fmt.Sprintf("Invalid IP '%s'", ipRaw)), nil
}
ip := ipAddr.String()
ipMatched, err := cidrContainsIP(ip, role.CIDR)
ipMatched, err := cidrContainsIP(ip, role.CIDRList)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error validating IP: %s", err)), nil
}
@ -137,8 +131,8 @@ func (b *backend) pathCredsCreateWrite(
}
if lease == nil {
result.Secret.Lease = defaultSSHLeaseDuration
result.Secret.LeaseGracePeriod = 0
result.Secret.Lease = 10 * time.Minute
result.Secret.LeaseGracePeriod = 2 * time.Minute
}
return result, nil

View file

@ -34,12 +34,12 @@ func pathRoles(b *backend) *framework.Path {
Type: framework.TypeString,
Description: "Default user to whom the dynamic key is installed",
},
"cidr": &framework.FieldSchema{
"cidr_list": &framework.FieldSchema{
Type: framework.TypeString,
Description: "CIDR blocks and IP addresses",
Description: "Comma separated CIDR blocks and IP addresses",
},
"port": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeInt,
Description: "Port number for SSH connection",
},
"key_type": &framework.FieldSchema{
@ -73,20 +73,20 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
return logical.ErrorResponse("Missing role name"), nil
}
cidr := d.Get("cidr").(string)
if cidr == "" {
cidrList := d.Get("cidr_list").(string)
if cidrList == "" {
return logical.ErrorResponse("Missing CIDR blocks"), nil
}
for _, item := range strings.Split(cidr, ",") {
for _, item := range strings.Split(cidrList, ",") {
_, _, err := net.ParseCIDR(item)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Invalid cidr entry '%s'", item)), nil
return logical.ErrorResponse(fmt.Sprintf("Invalid CIDR list entry '%s'", item)), nil
}
}
port := d.Get("port").(string)
if port == "" {
port = "22"
port := d.Get("port").(int)
if port == 0 {
port = 22
}
keyType := d.Get("key_type").(string)
@ -110,7 +110,7 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
entry, err = logical.StorageEntryJSON(fmt.Sprintf("roles/%s", roleName), sshRole{
DefaultUser: defaultUser,
CIDR: cidr,
CIDRList: cidrList,
KeyType: KeyTypeOTP,
Port: port,
})
@ -154,7 +154,7 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
KeyName: keyName,
AdminUser: adminUser,
DefaultUser: defaultUser,
CIDR: cidr,
CIDRList: cidrList,
Port: port,
KeyType: KeyTypeDynamic,
KeyBits: keyBits,
@ -193,7 +193,7 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
return &logical.Response{
Data: map[string]interface{}{
"default_user": role.DefaultUser,
"cidr": role.CIDR,
"cidr_list": role.CIDRList,
"port": role.Port,
"key_type": role.KeyType,
},
@ -204,7 +204,7 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
"key": role.KeyName,
"admin_user": role.AdminUser,
"default_user": role.DefaultUser,
"cidr": role.CIDR,
"cidr_list": role.CIDRList,
"port": role.Port,
"key_type": role.KeyType,
},
@ -227,8 +227,8 @@ type sshRole struct {
KeyBits string `mapstructure:"key_bits" json:"key_bits"`
AdminUser string `mapstructure:"admin_user" json:"admin_user"`
DefaultUser string `mapstructure:"default_user" json:"default_user"`
CIDR string `mapstructure:"cidr" json:"cidr"`
Port string `mapstructure:"port" json:"port"`
CIDRList string `mapstructure:"cidr_list" json:"cidr_list"`
Port int `mapstructure:"port" json:"port"`
InstallScript string `mapstructure:"install_script" json:"install_script"`
}
@ -244,8 +244,8 @@ is mounted at "ssh" and the role is created at "ssh/roles/web",
then a user could request for a new key at "ssh/creds/web" for the
supplied username and IP address.
The 'cidr' field takes comma seperated CIDR blocks. The 'admin_user'
should have root access in all the hosts represented by the 'cidr'
The 'cidr_list' field takes comma seperated CIDR blocks. The 'admin_user'
should have root access in all the hosts represented by the 'cidr_list'
field. When the user requests key for an IP, the key will be installed
for the user mentioned by 'default_user' field. The 'key' field takes
a named key which can be configured by 'ssh/keys/' endpoint.
@ -268,7 +268,7 @@ Role Options:
IP address, by default, this username is used to create the
credentials. Required for 'otp' type. Optional for 'dynamic' type.
-cidr CIDR block for which is role is applicable for. Required field
-cidr_list CIDR block for which is role is applicable for. Required field
for both types.
-port Port number for SSH connections. Default is '22'. Optional for

View file

@ -102,7 +102,7 @@ func (b *backend) secretDynamicKeyRevoke(req *logical.Request, d *framework.Fiel
if !ok {
return nil, fmt.Errorf("secret is missing internal data")
}
port := portRaw.(string)
port := int(portRaw.(float64))
// Fetch the host key using the key name
hostKeyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", hostKeyName))

View file

@ -14,14 +14,13 @@ import (
"github.com/hashicorp/vault/logical"
commssh "github.com/mitchellh/packer/communicator/ssh"
"golang.org/x/crypto/ssh"
)
// Creates a SSH session object which can be used to run commands
// in the target machine. The session will use public key authentication
// method with port 22.
func createSSHPublicKeysSession(username, ipAddr, port, hostKey string) (*ssh.Session, error) {
func createSSHPublicKeysSession(username, ipAddr string, port int, hostKey string) (*ssh.Session, error) {
if username == "" {
return nil, fmt.Errorf("missing username")
}
@ -43,7 +42,7 @@ func createSSHPublicKeysSession(username, ipAddr, port, hostKey string) (*ssh.Se
},
}
client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%s", ipAddr, port), config)
client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ipAddr, port), config)
if err != nil {
return nil, err
}
@ -82,7 +81,7 @@ func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, er
// Concatenates the public present in that target machine's home
// folder to ~/.ssh/authorized_keys file
func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port, hostkey string) error {
func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip string, port int, hostkey string) error {
session, err := createSSHPublicKeysSession(adminUser, ip, port, hostkey)
if err != nil {
return fmt.Errorf("unable to create SSH Session using public keys: %s", err)
@ -107,7 +106,7 @@ func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port,
// Removes the installed public key from the authorized_keys file
// in target machine
func uninstallPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port, hostKey string) error {
func uninstallPublicKeyInTarget(adminUser, publicKeyFileName, username, ip string, port int, hostKey string) error {
session, err := createSSHPublicKeysSession(adminUser, ip, port, hostKey)
if err != nil {
return fmt.Errorf("unable to create SSH Session using public keys: %s", err)
@ -155,7 +154,7 @@ func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error)
return false, fmt.Errorf("error decoding role '%s'", roleName)
}
if matched, err := cidrContainsIP(ip, role.CIDR); err != nil {
if matched, err := cidrContainsIP(ip, role.CIDRList); err != nil {
return false, err
} else {
return matched, nil
@ -164,11 +163,11 @@ func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error)
// Returns true if the IP supplied by the user is part of the comma
// separated CIDR blocks
func cidrContainsIP(ip, cidr string) (bool, error) {
for _, item := range strings.Split(cidr, ",") {
func cidrContainsIP(ip, cidrList string) (bool, error) {
for _, item := range strings.Split(cidrList, ",") {
_, cidrIPNet, err := net.ParseCIDR(item)
if err != nil {
return false, fmt.Errorf("invalid cidr entry '%s'", item)
return false, fmt.Errorf("invalid CIDR entry '%s'", item)
}
if cidrIPNet.Contains(net.ParseIP(ip)) {
return true, nil
@ -177,7 +176,7 @@ func cidrContainsIP(ip, cidr string) (bool, error) {
return false, nil
}
func scpUpload(username, ip, port, hostkey, fileName, fileContent string) error {
func scpUpload(username, ip string, port int, hostkey, fileName, fileContent string) error {
signer, err := ssh.ParsePrivateKey([]byte(hostkey))
clientConfig := &ssh.ClientConfig{
User: username,
@ -187,7 +186,7 @@ func scpUpload(username, ip, port, hostkey, fileName, fileContent string) error
}
connfunc := func() (net.Conn, error) {
c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%s", ip, port), 15*time.Second)
c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", ip, port), 15*time.Second)
if err != nil {
return nil, err
}
@ -199,13 +198,13 @@ func scpUpload(username, ip, port, hostkey, fileName, fileContent string) error
return c, nil
}
config := &commssh.Config{
config := &SSHCommConfig{
SSHConfig: clientConfig,
Connection: connfunc,
Pty: false,
DisableAgent: true,
}
comm, err := commssh.New(fmt.Sprintf("%s:%s", ip, port), config)
comm, err := SSHCommNew(fmt.Sprintf("%s:%d", ip, port), config)
if err != nil {
return fmt.Errorf("error connecting to target: %s", err)
}