// open-nomad/e2e/framework/provisioning/ssh_runner.go

package provisioning
import (
"context"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
)
// SSHRunner is a ProvisioningRunner that deploys via ssh.
// Terraform does all of this more elegantly and portably in its
// ssh communicator, but by shelling out we avoid pulling in TF's as
// a Nomad dependency, and avoid some long-standing issues with
// connections to Windows servers. The tradeoff is losing portability
// but in practice we're always going to run this from a Unixish
// machine.
//
// The exported fields describe the target host and are passed to every
// local ssh/scp subprocess the runner spawns; set them before Open.
type SSHRunner struct {
	Key  string // `json:"key"`: private key file passed to ssh/scp with -i
	User string // `json:"user"`: remote login user
	Host string // `json:"host"`: remote host; also labels log lines and names the control socket
	Port int    // `json:"port"`: remote ssh port (ssh -p / scp -P)

	// None of these are available at time of construction. Open
	// populates t, controlSockPath, ctx, cancelFunc and muxWait.
	// NOTE(review): Open does not set copyMethod; presumably the code
	// that builds the runner assigns it — verify.

	// t is the test this runner logs to; Log and Logf require it.
	t *testing.T
	// controlSockPath is the ssh control (multiplexing) socket under
	// ~/.ssh; run treats an empty value as "Open was not called".
	controlSockPath string
	// ctx bounds every ssh/scp subprocess (10-minute timeout set in Open).
	ctx context.Context
	// cancelFunc cancels ctx, killing the master connection and any
	// in-flight commands; called by Close and after a failed command.
	cancelFunc context.CancelFunc
	// copyMethod implements Copy; expected to be copyLinux or copyWindows,
	// which both match this signature.
	copyMethod func(*SSHRunner, string, string) error
	// muxWait is signalled by Open's background goroutine once the master
	// ssh process exits; Close waits on it.
	muxWait chan struct{}
}
// Open establishes the master ssh connection and records t for logging.
// We keep this connection open so that subsequent ssh/scp invocations
// can multiplex over its control socket (ssh -M -S) instead of each
// setting up a new session.
//
// Open does not wait for the connection to come up: the master ssh
// process runs in a background goroutine until Close cancels it or the
// runner's 10-minute deadline expires. An error is returned only when the
// local home directory, which holds the control socket, cannot be found.
func (runner *SSHRunner) Open(t *testing.T) error {
	runner.t = t
	runner.Logf("opening connection to %s", runner.Host)

	// The control socket lives in ~/.ssh; without a home directory the
	// joined path would silently become relative to the working directory.
	home, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("locating home directory for ssh control socket: %w", err)
	}

	// This context bounds every ssh/scp subprocess spawned by the runner.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	runner.ctx = ctx
	runner.cancelFunc = cancel

	// Keep a local reference so the goroutine below never re-reads the
	// struct field.
	done := make(chan struct{})
	runner.muxWait = done

	runner.controlSockPath = filepath.Join(
		home, ".ssh",
		fmt.Sprintf("ssh-control-%s-%d.sock", runner.Host, os.Getpid()))

	cmd := exec.CommandContext(ctx,
		"ssh",
		"-M", "-S", runner.controlSockPath,
		"-o", "StrictHostKeyChecking=no", // we're those terrible cloud devs
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-o", "ConnectTimeout=60", // give the target a while to come up
		"-i", runner.Key,
		"-p", fmt.Sprintf("%v", runner.Port),
		fmt.Sprintf("%s@%s", runner.User, runner.Host),
	)

	go func() {
		// Closing (rather than sending on) done lets this goroutine exit
		// even if Close is never called, and releases every waiter.
		defer close(done)
		// will block until command completes, we cancel, or timeout.
		// there's no point in returning the error here as we only
		// hit it when we're done and Windows unfortunately tends to
		// return 1 even when the script is complete.
		_ = cmd.Run()
	}()
	return nil
}
// Run executes script on the remote host one line at a time, each
// non-blank line as its own ssh command over the multiplexed connection.
//
// Execution stops at the first failing line. The runner's context is
// cancelled, which aborts the master connection and any in-flight
// commands, and that line's error is returned.
func (runner *SSHRunner) Run(script string) error {
	for _, line := range strings.Split(strings.TrimSpace(script), "\n") {
		command := strings.TrimSpace(line)
		if command == "" {
			// Nothing to run; don't spend an ssh round-trip on an empty
			// command.
			continue
		}
		if err := runner.run(command); err != nil {
			// If Open was never called there is no context to cancel, and
			// run's own error already explains the problem.
			if runner.cancelFunc != nil {
				runner.cancelFunc()
			}
			return err
		}
	}
	return nil
}
// run executes a single command on the remote host through the ssh
// control socket created by Open. It logs the command and the combined
// stdout/stderr.
//
// An error is returned if Open has not been called or the command fails.
// An error equal to context.Canceled is treated as success.
func (runner *SSHRunner) run(command string) error {
	if len(runner.controlSockPath) == 0 {
		return fmt.Errorf("Run failed: you need to call Open() first")
	}
	runner.Logf("running '%s'", command)

	target := fmt.Sprintf("%s@%s", runner.User, runner.Host)
	port := fmt.Sprintf("%v", runner.Port)
	sshArgs := []string{
		"-S", runner.controlSockPath,
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-i", runner.Key,
		"-p", port,
		target,
		command,
	}

	combined, runErr := exec.CommandContext(runner.ctx, "ssh", sshArgs...).CombinedOutput()
	output := string(combined)
	switch {
	case runErr == nil, runErr == context.Canceled:
		runner.LogOutput(output)
		return nil
	default:
		runner.LogErrOutput(output)
		return runErr
	}
}
// Copy uploads the local path to the remote path. We call into
// different copy methods for Linux vs Windows because their path
// semantics are slightly different and the typical ssh users have
// different permissions.
//
// If no copy method has been configured for the runner, Copy returns an
// error rather than panicking on a nil function value.
func (runner *SSHRunner) Copy(local, remote string) error {
	if runner.copyMethod == nil {
		return fmt.Errorf("copy failed: no copy method configured for host %q", runner.Host)
	}
	return runner.copyMethod(runner, local, remote)
}
// copyLinux uploads local (a file or a directory) to remote on a Linux
// host. The upload is staged in /tmp and then moved into place with sudo,
// so the ssh user does not need write access to the destination.
//
// remote must name the destination file or directory itself. A path that
// ends in a separator is rejected, because it has no final element to
// stage under /tmp.
//
// TODO: would be nice to set file owner/mode here
func copyLinux(runner *SSHRunner, local, remote string) error {
	runner.Logf("copying '%s' to '%s'", local, remote)
	remoteDir, remoteFileName := filepath.Split(remote)
	if remoteFileName == "" {
		// Without a file name the staging path would be "/tmp/" itself, and
		// the sudo mv/rm below would act on the whole /tmp directory.
		return fmt.Errorf("copy failed: remote path %q must not end in a path separator", remote)
	}
	// we stage to /tmp so that we can handle root-owned files
	tempPath := fmt.Sprintf("/tmp/%s", remoteFileName)

	cmd := exec.CommandContext(runner.ctx,
		"scp", "-r",
		"-o", fmt.Sprintf("ControlPath=%s", runner.controlSockPath),
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-i", runner.Key,
		"-P", fmt.Sprintf("%v", runner.Port),
		local,
		fmt.Sprintf("%s@%s:%s", runner.User, runner.Host, tempPath))
	stdoutStderr, err := cmd.CombinedOutput()
	if err != nil && err != context.Canceled {
		runner.LogErrOutput(string(stdoutStderr))
		runner.cancelFunc()
		return err
	}

	fi, err := os.Stat(local)
	if err != nil {
		// Report through the error return instead of failing the test from
		// inside a helper, and keep the underlying cause.
		return fmt.Errorf("could not read '%s': %w", local, err)
	}
	if fi.IsDir() {
		// this is a little inefficient but it lets us merge the contents of
		// a bundled directory with existing directories
		return runner.Run(
			fmt.Sprintf("sudo mkdir -p %s; sudo cp -R %s %s; sudo rm -r %s",
				remote, tempPath, remoteDir, tempPath))
	}
	return runner.run(fmt.Sprintf("sudo mv %s %s", tempPath, remoteDir))
}
// copyWindows uploads local (a file or a directory) to remote on a
// Windows host, copying directly with scp and no staging step.
//
// Staging to Windows tempdirs is a little messier, but "fortunately"
// nobody seems to complain about connecting via ssh as Administrator on
// Windows, so we can just bypass the problem.
//
// The scp error is returned as-is, including context.Canceled. The
// runner's context is cancelled for any other failure.
func copyWindows(runner *SSHRunner, local, remote string) error {
	runner.Logf("copying '%s' to '%s'", local, remote)
	remoteDir, _ := filepath.Split(remote)
	fi, err := os.Stat(local)
	if err != nil {
		// Report through the error return instead of failing the test from
		// inside a helper, and keep the underlying cause.
		return fmt.Errorf("could not read '%s': %w", local, err)
	}

	// scp -r copies a directory as a child of its destination, so target
	// the parent directory; a single file is copied to remote directly.
	remotePath := remote
	if fi.IsDir() {
		remotePath = remoteDir
	}

	cmd := exec.CommandContext(runner.ctx,
		"scp", "-r",
		"-o", fmt.Sprintf("ControlPath=%s", runner.controlSockPath),
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-i", runner.Key,
		"-P", fmt.Sprintf("%v", runner.Port),
		local,
		fmt.Sprintf("%s@%s:'%s'", runner.User, runner.Host, remotePath))
	stdoutStderr, err := cmd.CombinedOutput()
	if err != nil && err != context.Canceled {
		runner.LogErrOutput(string(stdoutStderr))
		runner.cancelFunc()
	}
	return err
}
// Close tears down the master ssh connection opened by Open. It cancels
// the runner's context, which kills the master ssh process and any
// commands still running under it. It then waits for Open's background
// goroutine to signal that the master process has exited.
//
// Close is a no-op if Open never set up a connection. Calling it more
// than once does not block.
func (runner *SSHRunner) Close() {
	if runner.cancelFunc == nil {
		return
	}
	runner.Log("closing connection")
	runner.cancelFunc()
	if runner.muxWait != nil {
		<-runner.muxWait
		// The background goroutine signals only once; drop the channel so
		// a repeated Close doesn't wait forever for a second signal.
		runner.muxWait = nil
	}
}
// 'go test -v' only emits logs after the entire test run is complete,
// but that makes it much harder to debug hanging deployments. These
// methods wrap the test logger or just emit directly w/ fmt.Print if
// the '-v' flag was set.

// Log writes args, prefixed with "[<host>]", straight to stdout under
// 'go test -v' and to the test log otherwise. It exits the process via
// log.Fatal if Open has not yet supplied a *testing.T.
func (runner *SSHRunner) Log(args ...interface{}) {
	if runner.t == nil {
		log.Fatal("no t.Testing configured for SSHRunner")
	}
	prefix := "[" + runner.Host + "]"
	if testing.Verbose() {
		// Print the prefix verbatim: using it as a format string would
		// misinterpret any '%' in the host (e.g. an IPv6 zone).
		fmt.Print(prefix + " ")
		fmt.Println(args...)
		return
	}
	// Match Logf, which prefixes the host in both modes.
	runner.t.Log(append([]interface{}{prefix}, args...)...)
}
// Logf formats args according to format and prefixes the line with
// "[<host>]". The line goes to stdout under 'go test -v' and to the test
// log otherwise. It exits the process via log.Fatal if Open has not yet
// supplied a *testing.T.
func (runner *SSHRunner) Logf(format string, args ...interface{}) {
	if runner.t == nil {
		log.Fatal("no t.Testing configured for SSHRunner")
	}
	// Format the caller's message on its own so the host is never parsed
	// as part of the format string; a '%' in the host (e.g. an IPv6 zone)
	// would otherwise consume or garble arguments.
	msg := fmt.Sprintf(format, args...)
	if testing.Verbose() {
		fmt.Printf("[%s] %s\n", runner.Host, msg)
		return
	}
	runner.t.Logf("[%s] %s", runner.Host, msg)
}
// LogOutput records output captured from a remote command (run uses it
// for commands that succeeded). Under 'go test -v' the output is printed
// in green; otherwise it goes to the test log.
func (runner *SSHRunner) LogOutput(output string) {
	if !testing.Verbose() {
		runner.t.Log(output)
		return
	}
	const green, reset = "\033[32m", "\033[0m"
	fmt.Println(green + output + reset)
}
// LogErrOutput records output captured from a failed remote command or
// copy. Under 'go test -v' the output is printed in red; otherwise it
// goes to the test log.
func (runner *SSHRunner) LogErrOutput(output string) {
	if verbose := testing.Verbose(); !verbose {
		runner.t.Log(output)
		return
	}
	const red, reset = "\033[31m", "\033[0m"
	fmt.Println(red + output + reset)
}