Vault SSH: cidr to cidr_list
This commit is contained in:
parent
7d3025fd6e
commit
8e946f27cc
|
@ -18,7 +18,7 @@ import (
|
|||
const (
|
||||
testOTPKeyType = "otp"
|
||||
testDynamicKeyType = "dynamic"
|
||||
testCidr = "127.0.0.1/32"
|
||||
testCIDRList = "127.0.0.1/32"
|
||||
testDynamicRoleName = "testDynamicRoleName"
|
||||
testOTPRoleName = "testOTPRoleName"
|
||||
testKeyName = "testKeyName"
|
||||
|
@ -55,7 +55,7 @@ oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
|
|||
|
||||
var testIP string
|
||||
var testOTP string
|
||||
var testPort string
|
||||
var testPort int
|
||||
var testUserName string
|
||||
var testAdminUser string
|
||||
var testInstallScript string
|
||||
|
@ -92,13 +92,13 @@ func TestSSHBackend_Lookup(t *testing.T) {
|
|||
otpData := map[string]interface{}{
|
||||
"key_type": testOTPKeyType,
|
||||
"default_user": testUserName,
|
||||
"cidr": testCidr,
|
||||
"cidr_list": testCIDRList,
|
||||
}
|
||||
dynamicData := map[string]interface{}{
|
||||
"key_type": testDynamicKeyType,
|
||||
"key": testKeyName,
|
||||
"admin_user": testAdminUser,
|
||||
"cidr": testCidr,
|
||||
"cidr_list": testCIDRList,
|
||||
"install_script": testInstallScript,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
|
@ -133,7 +133,7 @@ func TestSSHBackend_OTPRoleCrud(t *testing.T) {
|
|||
data := map[string]interface{}{
|
||||
"key_type": testOTPKeyType,
|
||||
"default_user": testUserName,
|
||||
"cidr": testCidr,
|
||||
"cidr_list": testCIDRList,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Factory: Factory,
|
||||
|
@ -151,7 +151,7 @@ func TestSSHBackend_DynamicRoleCrud(t *testing.T) {
|
|||
"key_type": testDynamicKeyType,
|
||||
"key": testKeyName,
|
||||
"admin_user": testAdminUser,
|
||||
"cidr": testCidr,
|
||||
"cidr_list": testCIDRList,
|
||||
"install_script": testInstallScript,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
|
@ -182,7 +182,7 @@ func TestSSHBackend_OTPCreate(t *testing.T) {
|
|||
data := map[string]interface{}{
|
||||
"key_type": testOTPKeyType,
|
||||
"default_user": testUserName,
|
||||
"cidr": testCidr,
|
||||
"cidr_list": testCIDRList,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Factory: Factory,
|
||||
|
@ -303,11 +303,11 @@ func testRoleRead(t *testing.T, name string, data map[string]interface{}) logica
|
|||
return fmt.Errorf("error decoding response:%s", err)
|
||||
}
|
||||
if name == testOTPRoleName {
|
||||
if d.KeyType != data["key_type"] || d.DefaultUser != data["default_user"] || d.CIDR != data["cidr"] {
|
||||
if d.KeyType != data["key_type"] || d.DefaultUser != data["default_user"] || d.CIDRList != data["cidr_list"] {
|
||||
return fmt.Errorf("data mismatch. bad: %#v", resp)
|
||||
}
|
||||
} else {
|
||||
if d.AdminUser != data["admin_user"] || d.CIDR != data["cidr"] || d.KeyName != data["key"] || d.KeyType != data["key_type"] {
|
||||
if d.AdminUser != data["admin_user"] || d.CIDRList != data["cidr_list"] || d.KeyName != data["key"] || d.KeyType != data["key_type"] {
|
||||
return fmt.Errorf("data mismatch. bad: %#v", resp)
|
||||
}
|
||||
}
|
||||
|
@ -331,7 +331,7 @@ func testNewDynamicKeyRole(t *testing.T) logicaltest.TestStep {
|
|||
"key_type": "dynamic",
|
||||
"key": testKeyName,
|
||||
"admin_user": testAdminUser,
|
||||
"cidr": testCidr,
|
||||
"cidr_list": testCIDRList,
|
||||
"port": testPort,
|
||||
"install_script": testInstallScript,
|
||||
},
|
||||
|
|
338
builtin/logical/ssh/communicator.go
Normal file
338
builtin/logical/ssh/communicator.go
Normal file
|
@ -0,0 +1,338 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
type comm struct {
|
||||
client *ssh.Client
|
||||
config *SSHCommConfig
|
||||
conn net.Conn
|
||||
address string
|
||||
}
|
||||
|
||||
// SSHCommConfig is the structure used to configure the SSH communicator.
|
||||
type SSHCommConfig struct {
|
||||
// The configuration of the Go SSH connection
|
||||
SSHConfig *ssh.ClientConfig
|
||||
|
||||
// Connection returns a new connection. The current connection
|
||||
// in use will be closed as part of the Close method, or in the
|
||||
// case an error occurs.
|
||||
Connection func() (net.Conn, error)
|
||||
|
||||
// Pty, if true, will request a pty from the remote end.
|
||||
Pty bool
|
||||
|
||||
// DisableAgent, if true, will not forward the SSH agent.
|
||||
DisableAgent bool
|
||||
}
|
||||
|
||||
// Creates a new communicator implementation over SSH. This takes
|
||||
// an already existing TCP connection and SSH configuration.
|
||||
func SSHCommNew(address string, config *SSHCommConfig) (result *comm, err error) {
|
||||
// Establish an initial connection and connect
|
||||
result = &comm{
|
||||
config: config,
|
||||
address: address,
|
||||
}
|
||||
|
||||
if err = result.reconnect(); err != nil {
|
||||
result = nil
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
|
||||
// The target directory and file for talking the SCP protocol
|
||||
target_dir := filepath.Dir(path)
|
||||
target_file := filepath.Base(path)
|
||||
|
||||
// On windows, filepath.Dir uses backslash seperators (ie. "\tmp").
|
||||
// This does not work when the target host is unix. Switch to forward slash
|
||||
// which works for unix and windows
|
||||
target_dir = filepath.ToSlash(target_dir)
|
||||
|
||||
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
|
||||
return scpUploadFile(target_file, input, w, stdoutR, fi)
|
||||
}
|
||||
|
||||
return c.scpSession("scp -vt "+target_dir, scpFunc)
|
||||
}
|
||||
|
||||
func (c *comm) newSession() (session *ssh.Session, err error) {
|
||||
if c.client == nil {
|
||||
err = errors.New("client not available")
|
||||
} else {
|
||||
session, err = c.client.NewSession()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("ssh session open error: '%s', attempting reconnect", err)
|
||||
if err := c.reconnect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.client.NewSession()
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *comm) reconnect() (err error) {
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
// Set the conn and client to nil since we'll recreate it
|
||||
c.conn = nil
|
||||
c.client = nil
|
||||
|
||||
c.conn, err = c.config.Connection()
|
||||
if err != nil {
|
||||
// Explicitly set this to the REAL nil. Connection() can return
|
||||
// a nil implementation of net.Conn which will make the
|
||||
// "if c.conn == nil" check fail above. Read here for more information
|
||||
// on this psychotic language feature:
|
||||
//
|
||||
// http://golang.org/doc/faq#nil_error
|
||||
c.conn = nil
|
||||
log.Printf("reconnection error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
|
||||
if err != nil {
|
||||
log.Printf("handshake error: %s", err)
|
||||
}
|
||||
if sshConn != nil {
|
||||
c.client = ssh.NewClient(sshConn, sshChan, req)
|
||||
}
|
||||
c.connectToAgent()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *comm) connectToAgent() {
|
||||
if c.client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.config.DisableAgent {
|
||||
return
|
||||
}
|
||||
|
||||
// open connection to the local agent
|
||||
socketLocation := os.Getenv("SSH_AUTH_SOCK")
|
||||
if socketLocation == "" {
|
||||
return
|
||||
}
|
||||
agentConn, err := net.Dial("unix", socketLocation)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
|
||||
return
|
||||
}
|
||||
|
||||
// create agent and add in auth
|
||||
forwardingAgent := agent.NewClient(agentConn)
|
||||
if forwardingAgent == nil {
|
||||
log.Printf("[ERROR] Could not create agent client")
|
||||
agentConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// add callback for forwarding agent to SSH config
|
||||
// XXX - might want to handle reconnects appending multiple callbacks
|
||||
auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
|
||||
c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth)
|
||||
agent.ForwardToAgent(c.client, forwardingAgent)
|
||||
|
||||
// Setup a session to request agent forwarding
|
||||
session, err := c.newSession()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
err = agent.RequestAgentForwarding(session)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] RequestAgentForwarding: %#v", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
|
||||
session, err := c.newSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// Get a pipe to stdin so that we can send data down
|
||||
stdinW, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We only want to close once, so we nil w after we close it,
|
||||
// and only close in the defer if it hasn't been closed already.
|
||||
defer func() {
|
||||
if stdinW != nil {
|
||||
stdinW.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Get a pipe to stdout so that we can get responses back
|
||||
stdoutPipe, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stdoutR := bufio.NewReader(stdoutPipe)
|
||||
|
||||
// Set stderr to a bytes buffer
|
||||
stderr := new(bytes.Buffer)
|
||||
session.Stderr = stderr
|
||||
|
||||
// Start the sink mode on the other side
|
||||
if err := session.Start(scpCommand); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Call our callback that executes in the context of SCP. We ignore
|
||||
// EOF errors if they occur because it usually means that SCP prematurely
|
||||
// ended on the other side.
|
||||
if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
// Close the stdin, which sends an EOF, and then set w to nil so that
|
||||
// our defer func doesn't close it again since that is unsafe with
|
||||
// the Go SSH package.
|
||||
stdinW.Close()
|
||||
stdinW = nil
|
||||
|
||||
// Wait for the SCP connection to close, meaning it has consumed all
|
||||
// our data and has completed. Or has errored.
|
||||
err = session.Wait()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*ssh.ExitError); ok {
|
||||
// Otherwise, we have an ExitErorr, meaning we can just read
|
||||
// the exit status
|
||||
log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
|
||||
|
||||
// If we exited with status 127, it means SCP isn't available.
|
||||
// Return a more descriptive error for that.
|
||||
if exitErr.ExitStatus() == 127 {
|
||||
return errors.New(
|
||||
"SCP failed to start. This usually means that SCP is not\n" +
|
||||
"properly installed on the remote system.")
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSCPStatus checks that a prior command sent to SCP completed
|
||||
// successfully. If it did not complete successfully, an error will
|
||||
// be returned.
|
||||
func checkSCPStatus(r *bufio.Reader) error {
|
||||
code, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if code != 0 {
|
||||
// Treat any non-zero (really 1 and 2) as fatal errors
|
||||
message, _, err := r.ReadLine()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error reading error message: %s", err)
|
||||
}
|
||||
|
||||
return errors.New(string(message))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
|
||||
var mode os.FileMode
|
||||
var size int64
|
||||
|
||||
if fi != nil && (*fi).Mode().IsRegular() {
|
||||
mode = (*fi).Mode().Perm()
|
||||
size = (*fi).Size()
|
||||
} else {
|
||||
// Create a temporary file where we can copy the contents of the src
|
||||
// so that we can determine the length, since SCP is length-prefixed.
|
||||
tf, err := ioutil.TempFile("", "vault-ssh-upload")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating temporary file for upload: %s", err)
|
||||
}
|
||||
defer os.Remove(tf.Name())
|
||||
defer tf.Close()
|
||||
|
||||
mode = 0644
|
||||
|
||||
log.Println("Copying input data into temporary file so we can read the length")
|
||||
if _, err := io.Copy(tf, src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sync the file so that the contents are definitely on disk, then
|
||||
// read the length of it.
|
||||
if err := tf.Sync(); err != nil {
|
||||
return fmt.Errorf("Error creating temporary file for upload: %s", err)
|
||||
}
|
||||
|
||||
// Seek the file to the beginning so we can re-read all of it
|
||||
if _, err := tf.Seek(0, 0); err != nil {
|
||||
return fmt.Errorf("Error creating temporary file for upload: %s", err)
|
||||
}
|
||||
|
||||
tfi, err := tf.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating temporary file for upload: %s", err)
|
||||
}
|
||||
|
||||
size = tfi.Size()
|
||||
src = tf
|
||||
}
|
||||
|
||||
// Start the protocol
|
||||
perms := fmt.Sprintf("C%04o", mode)
|
||||
|
||||
fmt.Fprintln(w, perms, size, dst)
|
||||
if err := checkSCPStatus(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.CopyN(w, src, size); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprint(w, "\x00")
|
||||
if err := checkSCPStatus(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -11,17 +11,11 @@ import (
|
|||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const defaultSSHLeaseDuration = 5 * time.Minute
|
||||
|
||||
type sshOTP struct {
|
||||
Username string `json:"username"`
|
||||
IP string `json:"ip"`
|
||||
}
|
||||
|
||||
type sshCIDR struct {
|
||||
CIDR []string
|
||||
}
|
||||
|
||||
func pathCredsCreate(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/(?P<role>[-\\w]+)",
|
||||
|
@ -87,7 +81,7 @@ func (b *backend) pathCredsCreateWrite(
|
|||
return logical.ErrorResponse(fmt.Sprintf("Invalid IP '%s'", ipRaw)), nil
|
||||
}
|
||||
ip := ipAddr.String()
|
||||
ipMatched, err := cidrContainsIP(ip, role.CIDR)
|
||||
ipMatched, err := cidrContainsIP(ip, role.CIDRList)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error validating IP: %s", err)), nil
|
||||
}
|
||||
|
@ -137,8 +131,8 @@ func (b *backend) pathCredsCreateWrite(
|
|||
}
|
||||
|
||||
if lease == nil {
|
||||
result.Secret.Lease = defaultSSHLeaseDuration
|
||||
result.Secret.LeaseGracePeriod = 0
|
||||
result.Secret.Lease = 10 * time.Minute
|
||||
result.Secret.LeaseGracePeriod = 2 * time.Minute
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
|
|
@ -34,12 +34,12 @@ func pathRoles(b *backend) *framework.Path {
|
|||
Type: framework.TypeString,
|
||||
Description: "Default user to whom the dynamic key is installed",
|
||||
},
|
||||
"cidr": &framework.FieldSchema{
|
||||
"cidr_list": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "CIDR blocks and IP addresses",
|
||||
Description: "Comma separated CIDR blocks and IP addresses",
|
||||
},
|
||||
"port": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Type: framework.TypeInt,
|
||||
Description: "Port number for SSH connection",
|
||||
},
|
||||
"key_type": &framework.FieldSchema{
|
||||
|
@ -73,20 +73,20 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
|
|||
return logical.ErrorResponse("Missing role name"), nil
|
||||
}
|
||||
|
||||
cidr := d.Get("cidr").(string)
|
||||
if cidr == "" {
|
||||
cidrList := d.Get("cidr_list").(string)
|
||||
if cidrList == "" {
|
||||
return logical.ErrorResponse("Missing CIDR blocks"), nil
|
||||
}
|
||||
for _, item := range strings.Split(cidr, ",") {
|
||||
for _, item := range strings.Split(cidrList, ",") {
|
||||
_, _, err := net.ParseCIDR(item)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Invalid cidr entry '%s'", item)), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("Invalid CIDR list entry '%s'", item)), nil
|
||||
}
|
||||
}
|
||||
|
||||
port := d.Get("port").(string)
|
||||
if port == "" {
|
||||
port = "22"
|
||||
port := d.Get("port").(int)
|
||||
if port == 0 {
|
||||
port = 22
|
||||
}
|
||||
|
||||
keyType := d.Get("key_type").(string)
|
||||
|
@ -110,7 +110,7 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
|
|||
|
||||
entry, err = logical.StorageEntryJSON(fmt.Sprintf("roles/%s", roleName), sshRole{
|
||||
DefaultUser: defaultUser,
|
||||
CIDR: cidr,
|
||||
CIDRList: cidrList,
|
||||
KeyType: KeyTypeOTP,
|
||||
Port: port,
|
||||
})
|
||||
|
@ -154,7 +154,7 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
|
|||
KeyName: keyName,
|
||||
AdminUser: adminUser,
|
||||
DefaultUser: defaultUser,
|
||||
CIDR: cidr,
|
||||
CIDRList: cidrList,
|
||||
Port: port,
|
||||
KeyType: KeyTypeDynamic,
|
||||
KeyBits: keyBits,
|
||||
|
@ -193,7 +193,7 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
|
|||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"default_user": role.DefaultUser,
|
||||
"cidr": role.CIDR,
|
||||
"cidr_list": role.CIDRList,
|
||||
"port": role.Port,
|
||||
"key_type": role.KeyType,
|
||||
},
|
||||
|
@ -204,7 +204,7 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
|
|||
"key": role.KeyName,
|
||||
"admin_user": role.AdminUser,
|
||||
"default_user": role.DefaultUser,
|
||||
"cidr": role.CIDR,
|
||||
"cidr_list": role.CIDRList,
|
||||
"port": role.Port,
|
||||
"key_type": role.KeyType,
|
||||
},
|
||||
|
@ -227,8 +227,8 @@ type sshRole struct {
|
|||
KeyBits string `mapstructure:"key_bits" json:"key_bits"`
|
||||
AdminUser string `mapstructure:"admin_user" json:"admin_user"`
|
||||
DefaultUser string `mapstructure:"default_user" json:"default_user"`
|
||||
CIDR string `mapstructure:"cidr" json:"cidr"`
|
||||
Port string `mapstructure:"port" json:"port"`
|
||||
CIDRList string `mapstructure:"cidr_list" json:"cidr_list"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
InstallScript string `mapstructure:"install_script" json:"install_script"`
|
||||
}
|
||||
|
||||
|
@ -244,8 +244,8 @@ is mounted at "ssh" and the role is created at "ssh/roles/web",
|
|||
then a user could request for a new key at "ssh/creds/web" for the
|
||||
supplied username and IP address.
|
||||
|
||||
The 'cidr' field takes comma seperated CIDR blocks. The 'admin_user'
|
||||
should have root access in all the hosts represented by the 'cidr'
|
||||
The 'cidr_list' field takes comma seperated CIDR blocks. The 'admin_user'
|
||||
should have root access in all the hosts represented by the 'cidr_list'
|
||||
field. When the user requests key for an IP, the key will be installed
|
||||
for the user mentioned by 'default_user' field. The 'key' field takes
|
||||
a named key which can be configured by 'ssh/keys/' endpoint.
|
||||
|
@ -268,7 +268,7 @@ Role Options:
|
|||
IP address, by default, this username is used to create the
|
||||
credentials. Required for 'otp' type. Optional for 'dynamic' type.
|
||||
|
||||
-cidr CIDR block for which is role is applicable for. Required field
|
||||
-cidr_list CIDR block for which is role is applicable for. Required field
|
||||
for both types.
|
||||
|
||||
-port Port number for SSH connections. Default is '22'. Optional for
|
||||
|
|
|
@ -102,7 +102,7 @@ func (b *backend) secretDynamicKeyRevoke(req *logical.Request, d *framework.Fiel
|
|||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing internal data")
|
||||
}
|
||||
port := portRaw.(string)
|
||||
port := int(portRaw.(float64))
|
||||
|
||||
// Fetch the host key using the key name
|
||||
hostKeyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", hostKeyName))
|
||||
|
|
|
@ -14,14 +14,13 @@ import (
|
|||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
||||
commssh "github.com/mitchellh/packer/communicator/ssh"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Creates a SSH session object which can be used to run commands
|
||||
// in the target machine. The session will use public key authentication
|
||||
// method with port 22.
|
||||
func createSSHPublicKeysSession(username, ipAddr, port, hostKey string) (*ssh.Session, error) {
|
||||
func createSSHPublicKeysSession(username, ipAddr string, port int, hostKey string) (*ssh.Session, error) {
|
||||
if username == "" {
|
||||
return nil, fmt.Errorf("missing username")
|
||||
}
|
||||
|
@ -43,7 +42,7 @@ func createSSHPublicKeysSession(username, ipAddr, port, hostKey string) (*ssh.Se
|
|||
},
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%s", ipAddr, port), config)
|
||||
client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ipAddr, port), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -82,7 +81,7 @@ func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, er
|
|||
|
||||
// Concatenates the public present in that target machine's home
|
||||
// folder to ~/.ssh/authorized_keys file
|
||||
func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port, hostkey string) error {
|
||||
func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip string, port int, hostkey string) error {
|
||||
session, err := createSSHPublicKeysSession(adminUser, ip, port, hostkey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create SSH Session using public keys: %s", err)
|
||||
|
@ -107,7 +106,7 @@ func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port,
|
|||
|
||||
// Removes the installed public key from the authorized_keys file
|
||||
// in target machine
|
||||
func uninstallPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port, hostKey string) error {
|
||||
func uninstallPublicKeyInTarget(adminUser, publicKeyFileName, username, ip string, port int, hostKey string) error {
|
||||
session, err := createSSHPublicKeysSession(adminUser, ip, port, hostKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create SSH Session using public keys: %s", err)
|
||||
|
@ -155,7 +154,7 @@ func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error)
|
|||
return false, fmt.Errorf("error decoding role '%s'", roleName)
|
||||
}
|
||||
|
||||
if matched, err := cidrContainsIP(ip, role.CIDR); err != nil {
|
||||
if matched, err := cidrContainsIP(ip, role.CIDRList); err != nil {
|
||||
return false, err
|
||||
} else {
|
||||
return matched, nil
|
||||
|
@ -164,11 +163,11 @@ func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error)
|
|||
|
||||
// Returns true if the IP supplied by the user is part of the comma
|
||||
// separated CIDR blocks
|
||||
func cidrContainsIP(ip, cidr string) (bool, error) {
|
||||
for _, item := range strings.Split(cidr, ",") {
|
||||
func cidrContainsIP(ip, cidrList string) (bool, error) {
|
||||
for _, item := range strings.Split(cidrList, ",") {
|
||||
_, cidrIPNet, err := net.ParseCIDR(item)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid cidr entry '%s'", item)
|
||||
return false, fmt.Errorf("invalid CIDR entry '%s'", item)
|
||||
}
|
||||
if cidrIPNet.Contains(net.ParseIP(ip)) {
|
||||
return true, nil
|
||||
|
@ -177,7 +176,7 @@ func cidrContainsIP(ip, cidr string) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
func scpUpload(username, ip, port, hostkey, fileName, fileContent string) error {
|
||||
func scpUpload(username, ip string, port int, hostkey, fileName, fileContent string) error {
|
||||
signer, err := ssh.ParsePrivateKey([]byte(hostkey))
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: username,
|
||||
|
@ -187,7 +186,7 @@ func scpUpload(username, ip, port, hostkey, fileName, fileContent string) error
|
|||
}
|
||||
|
||||
connfunc := func() (net.Conn, error) {
|
||||
c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%s", ip, port), 15*time.Second)
|
||||
c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", ip, port), 15*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -199,13 +198,13 @@ func scpUpload(username, ip, port, hostkey, fileName, fileContent string) error
|
|||
|
||||
return c, nil
|
||||
}
|
||||
config := &commssh.Config{
|
||||
config := &SSHCommConfig{
|
||||
SSHConfig: clientConfig,
|
||||
Connection: connfunc,
|
||||
Pty: false,
|
||||
DisableAgent: true,
|
||||
}
|
||||
comm, err := commssh.New(fmt.Sprintf("%s:%s", ip, port), config)
|
||||
comm, err := SSHCommNew(fmt.Sprintf("%s:%d", ip, port), config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to target: %s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue