open-nomad/client/allocrunner/taskrunner/getter/getter_test.go
Luiz Aoqui 5c100c0d3d
client: recover from getter panics (#14696)
The artifact getter uses the go-getter library to fetch files from
different sources. Any bug in this library that results in a panic can
cause the entire Nomad client to crash due to a single file download
attempt.

This change aims to guard against these types of crashes by recovering
from panics when the getter attempts to download an artifact. The
resulting panic is converted to an error that is stored as a task event
for operator visibility and the panic stack trace is logged to the
client's log.
2022-09-26 15:16:26 -04:00

566 lines
16 KiB
Go

package getter
import (
"fmt"
"io"
"io/ioutil"
"mime"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
"time"
gg "github.com/hashicorp/go-getter"
"github.com/hashicorp/go-hclog"
clientconfig "github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/client/interfaces"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/helper/escapingfs"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/stretchr/testify/require"
)
// noopReplacer is a noop version of taskenv.TaskEnv.ReplaceEnv. Its
// ReplaceEnv method returns the input unchanged, and its ClientPath method
// resolves paths against taskDir.
type noopReplacer struct {
// taskDir is the directory that ClientPath resolves paths against.
taskDir string
}
// clientPath resolves path against taskDir and reports whether the resolved
// path escapes the task directory.
//
// A relative path is always joined onto taskDir. An absolute path is joined
// onto taskDir only when it escapes the sandbox and join is true. The result
// is cleaned before being returned. The boolean result is true when taskDir
// is empty or when the cleaned path escapes taskDir.
func clientPath(taskDir, path string, join bool) (string, bool) {
	resolved := path
	if !filepath.IsAbs(resolved) || (escapingfs.PathEscapesSandbox(taskDir, resolved) && join) {
		resolved = filepath.Join(taskDir, resolved)
	}
	resolved = filepath.Clean(resolved)

	// With an empty taskDir the path is always reported as escaping; the
	// sandbox check is only consulted for a non-empty taskDir.
	escapes := taskDir == "" || escapingfs.PathEscapesSandbox(taskDir, resolved)
	return resolved, escapes
}
// ReplaceEnv returns s unchanged; no interpolation is performed.
func (noopReplacer) ReplaceEnv(s string) string {
return s
}
// ClientPath passes p through ReplaceEnv and resolves it against the
// replacer's taskDir using clientPath. It returns the resolved path and
// whether that path escapes the task directory.
func (r noopReplacer) ClientPath(p string, join bool) (string, bool) {
	return clientPath(r.taskDir, r.ReplaceEnv(p), join)
}
// noopTaskEnv returns an EnvReplacer backed by noopReplacer whose paths are
// resolved against taskDir.
func noopTaskEnv(taskDir string) interfaces.EnvReplacer {
	return noopReplacer{taskDir: taskDir}
}
// panicReplacer is a version of taskenv.TaskEnv.ReplaceEnv that panics. Both
// of its methods panic unconditionally when called.
type panicReplacer struct{}
// ReplaceEnv ignores its argument and always panics with the value "panic".
func (panicReplacer) ReplaceEnv(_ string) string {
panic("panic")
}
// ClientPath ignores its arguments and always panics with the value "panic".
func (panicReplacer) ClientPath(_ string, _ bool) (string, bool) {
panic("panic")
}
// panicTaskEnv returns an EnvReplacer backed by panicReplacer, so any method
// called on it panics.
func panicTaskEnv() interfaces.EnvReplacer {
return panicReplacer{}
}
// upperReplacer is a version of taskenv.TaskEnv.ReplaceEnv that upper-cases
// the given input.
type upperReplacer struct {
// taskDir is the directory that ClientPath resolves paths against.
taskDir string
}
// ReplaceEnv returns s converted to upper case via strings.ToUpper.
func (upperReplacer) ReplaceEnv(s string) string {
return strings.ToUpper(s)
}
// ClientPath upper-cases p through ReplaceEnv and resolves the result against
// the replacer's taskDir using clientPath. It returns the resolved path and
// whether that path escapes the task directory.
func (u upperReplacer) ClientPath(p string, join bool) (string, bool) {
	return clientPath(u.taskDir, u.ReplaceEnv(p), join)
}
// TestGetter_getClient checks that the go-getter client returned by
// Getter.getClient reflects the ArtifactConfig it was built from: symlinks are
// disabled, the http/https getters carry the read timeout and byte limit, and
// the gcs, git, hg and s3 getters carry their respective timeouts.
func TestGetter_getClient(t *testing.T) {
	getter := NewGetter(hclog.NewNullLogger(), &clientconfig.ArtifactConfig{
		HTTPReadTimeout: time.Minute,
		HTTPMaxBytes:    100_000,
		GCSTimeout:      1 * time.Minute,
		GitTimeout:      2 * time.Minute,
		HgTimeout:       3 * time.Minute,
		S3Timeout:       4 * time.Minute,
	})
	client := getter.getClient("src", nil, gg.ClientModeAny, "dst")

	t.Run("check symlink config", func(t *testing.T) {
		require.True(t, client.DisableSymlinks)
	})

	t.Run("check http config", func(t *testing.T) {
		require.True(t, client.Getters["http"].(*gg.HttpGetter).XTerraformGetDisabled)
		require.Equal(t, time.Minute, client.Getters["http"].(*gg.HttpGetter).ReadTimeout)
		require.Equal(t, int64(100_000), client.Getters["http"].(*gg.HttpGetter).MaxBytes)
	})

	t.Run("check https config", func(t *testing.T) {
		require.True(t, client.Getters["https"].(*gg.HttpGetter).XTerraformGetDisabled)
		require.Equal(t, time.Minute, client.Getters["https"].(*gg.HttpGetter).ReadTimeout)
		require.Equal(t, int64(100_000), client.Getters["https"].(*gg.HttpGetter).MaxBytes)
	})

	// require.Equal takes (expected, actual); the configured value goes first
	// so that failure output labels the two sides correctly.
	t.Run("check gcs config", func(t *testing.T) {
		require.Equal(t, 1*time.Minute, client.Getters["gcs"].(*gg.GCSGetter).Timeout)
	})

	t.Run("check git config", func(t *testing.T) {
		require.Equal(t, 2*time.Minute, client.Getters["git"].(*gg.GitGetter).Timeout)
	})

	t.Run("check hg config", func(t *testing.T) {
		require.Equal(t, 3*time.Minute, client.Getters["hg"].(*gg.HgGetter).Timeout)
	})

	t.Run("check s3 config", func(t *testing.T) {
		require.Equal(t, 4*time.Minute, client.Getters["s3"].(*gg.S3Getter).Timeout)
	})
}
func TestGetArtifact_getHeaders(t *testing.T) {
t.Run("nil", func(t *testing.T) {
require.Nil(t, getHeaders(noopTaskEnv(""), nil))
})
t.Run("empty", func(t *testing.T) {
require.Nil(t, getHeaders(noopTaskEnv(""), make(map[string]string)))
})
t.Run("set", func(t *testing.T) {
upperTaskEnv := new(upperReplacer)
expected := make(http.Header)
expected.Set("foo", "BAR")
result := getHeaders(upperTaskEnv, map[string]string{
"foo": "bar",
})
require.Equal(t, expected, result)
})
}
func TestGetArtifact_Headers(t *testing.T) {
file := "output.txt"
// Create the test server with a handler that will validate headers are set.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Validate the expected value for our header.
value := r.Header.Get("X-Some-Value")
require.Equal(t, "FOOBAR", value)
// Write the value to the file that is our artifact, for fun.
w.Header().Set("Content-Type", mime.TypeByExtension(filepath.Ext(file)))
w.WriteHeader(http.StatusOK)
_, err := io.Copy(w, strings.NewReader(value))
require.NoError(t, err)
}))
defer ts.Close()
// Create a temp directory to download into.
taskDir := t.TempDir()
// Create the artifact.
artifact := &structs.TaskArtifact{
GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
GetterHeaders: map[string]string{
"X-Some-Value": "foobar",
},
RelativeDest: file,
GetterMode: "file",
}
// Download the artifact.
getter := TestDefaultGetter(t)
taskEnv := upperReplacer{
taskDir: taskDir,
}
err := getter.GetArtifact(taskEnv, artifact)
require.NoError(t, err)
// Verify artifact exists.
b, err := ioutil.ReadFile(filepath.Join(taskDir, taskEnv.ReplaceEnv(file)))
require.NoError(t, err)
// Verify we wrote the interpolated header value into the file that is our
// artifact.
require.Equal(t, "FOOBAR", string(b))
}
// TestGetArtifact_FileAndChecksum downloads test.sh from a local HTTP server
// with an md5 checksum option and expects the file to appear at the root of
// the task directory.
func TestGetArtifact_FileAndChecksum(t *testing.T) {
	// Serve the fixture directory over HTTP.
	fixtures := http.Dir(filepath.Dir("./test-fixtures/"))
	ts := httptest.NewServer(http.FileServer(fixtures))
	defer ts.Close()

	// Download target.
	taskDir := t.TempDir()

	// Artifact pointing at the served fixture.
	file := "test.sh"
	artifact := &structs.TaskArtifact{
		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
		GetterOptions: map[string]string{
			"checksum": "md5:bce963762aa2dbfed13caf492a45fb72",
		},
	}

	// Fetch it.
	err := TestDefaultGetter(t).GetArtifact(noopTaskEnv(taskDir), artifact)
	if err != nil {
		t.Fatalf("GetArtifact failed: %v", err)
	}

	// The file must now exist in the task directory.
	_, err = os.Stat(filepath.Join(taskDir, file))
	if err != nil {
		t.Fatalf("file not found: %s", err)
	}
}
// TestGetArtifact_File_RelativeDest downloads test.sh with RelativeDest set
// to "foo/" and expects the file to land under that sub-directory of the task
// directory.
func TestGetArtifact_File_RelativeDest(t *testing.T) {
	// Serve the fixture directory over HTTP.
	fixtures := http.Dir(filepath.Dir("./test-fixtures/"))
	ts := httptest.NewServer(http.FileServer(fixtures))
	defer ts.Close()

	// Download target.
	taskDir := t.TempDir()

	// Artifact with a relative destination inside the task directory.
	file, relative := "test.sh", "foo/"
	artifact := &structs.TaskArtifact{
		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
		GetterOptions: map[string]string{
			"checksum": "md5:bce963762aa2dbfed13caf492a45fb72",
		},
		RelativeDest: relative,
	}

	// Fetch it.
	err := TestDefaultGetter(t).GetArtifact(noopTaskEnv(taskDir), artifact)
	if err != nil {
		t.Fatalf("GetArtifact failed: %v", err)
	}

	// The file must exist beneath the relative destination.
	_, err = os.Stat(filepath.Join(taskDir, relative, file))
	if err != nil {
		t.Fatalf("file not found: %s", err)
	}
}
// TestGetArtifact_File_EscapeDest uses a RelativeDest that climbs out of the
// task directory and expects GetArtifact to fail with an error mentioning
// "escapes".
func TestGetArtifact_File_EscapeDest(t *testing.T) {
	// Serve the fixture directory over HTTP.
	fixtures := http.Dir(filepath.Dir("./test-fixtures/"))
	ts := httptest.NewServer(http.FileServer(fixtures))
	defer ts.Close()

	// Download target.
	taskDir := t.TempDir()

	// Artifact whose destination points outside the task directory.
	file, relative := "test.sh", "../../../../foo/"
	artifact := &structs.TaskArtifact{
		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
		GetterOptions: map[string]string{
			"checksum": "md5:bce963762aa2dbfed13caf492a45fb72",
		},
		RelativeDest: relative,
	}

	// The download attempt must be rejected as a sandbox escape.
	err := TestDefaultGetter(t).GetArtifact(noopTaskEnv(taskDir), artifact)
	rejected := err != nil && strings.Contains(err.Error(), "escapes")
	if !rejected {
		t.Fatalf("expected GetArtifact to disallow sandbox escape: %v", err)
	}
}
func TestGetGetterUrl_Interpolation(t *testing.T) {
// Create the artifact
artifact := &structs.TaskArtifact{
GetterSource: "${NOMAD_META_ARTIFACT}",
}
url := "foo.com"
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Meta = map[string]string{"artifact": url}
taskEnv := taskenv.NewBuilder(mock.Node(), alloc, task, "global").Build()
act, err := getGetterUrl(taskEnv, artifact)
if err != nil {
t.Fatalf("getGetterUrl() failed: %v", err)
}
if act != url {
t.Fatalf("getGetterUrl() returned %q; want %q", act, url)
}
}
// TestGetArtifact_InvalidChecksum downloads test.sh with a checksum option
// that does not match the file and expects GetArtifact to return an error.
func TestGetArtifact_InvalidChecksum(t *testing.T) {
	// Serve the fixture directory over HTTP.
	fixtures := http.Dir(filepath.Dir("./test-fixtures/"))
	ts := httptest.NewServer(http.FileServer(fixtures))
	defer ts.Close()

	// Download target.
	taskDir := t.TempDir()

	// Artifact carrying a deliberately wrong checksum.
	file := "test.sh"
	artifact := &structs.TaskArtifact{
		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
		GetterOptions: map[string]string{
			"checksum": "md5:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
		},
	}

	// The download must fail.
	err := TestDefaultGetter(t).GetArtifact(noopTaskEnv(taskDir), artifact)
	if err == nil {
		t.Fatalf("GetArtifact should have failed")
	}
}
// createContents writes every entry of fileContents beneath basedir. Map keys
// are paths relative to basedir and map values are the file contents. Any
// missing parent directories are created first. The test is failed
// immediately if a directory or file cannot be created.
func createContents(basedir string, fileContents map[string]string, t *testing.T) {
	t.Helper()

	for relPath, content := range fileContents {
		file := filepath.Join(basedir, relPath)

		// MkdirAll creates missing parents and succeeds when the directory
		// already exists, so several entries may share a directory and
		// entries may be nested more than one level deep.
		if err := os.MkdirAll(filepath.Dir(file), 0777); err != nil {
			t.Fatalf("failed to make directory: %v", err)
		}

		if err := ioutil.WriteFile(file, []byte(content), 0777); err != nil {
			t.Fatalf("failed to write data to file %v: %v", file, err)
		}
	}
}
// checkContents asserts that every entry of fileContents exists beneath
// basedir with exactly the given contents. Map keys are paths relative to
// basedir and map values are the expected file contents. The test is failed
// immediately on the first unreadable or mismatching file.
func checkContents(basedir string, fileContents map[string]string, t *testing.T) {
	// Mark this function as a helper so failures are attributed to the
	// calling test rather than to this function.
	t.Helper()

	for relPath, content := range fileContents {
		path := filepath.Join(basedir, relPath)
		actual, err := ioutil.ReadFile(path)
		if err != nil {
			t.Fatalf("failed to read file %q: %v", path, err)
		}
		if !reflect.DeepEqual(actual, []byte(content)) {
			t.Fatalf("%q: expected %q; got %q", path, content, string(actual))
		}
	}
}
// TestGetArtifact_Archive downloads archive.tar.gz into a task directory that
// already holds some files. After the download, files that also exist in the
// archive must carry the archive's contents, files only in the archive must
// be present, and files not in the archive must be left untouched.
func TestGetArtifact_Archive(t *testing.T) {
	// Serve the fixture directory over HTTP.
	fixtures := http.Dir(filepath.Dir("./test-fixtures/"))
	ts := httptest.NewServer(http.FileServer(fixtures))
	defer ts.Close()

	// Seed the download target with files, one of which is expected to be
	// overwritten by the archive.
	taskDir := t.TempDir()
	preexisting := map[string]string{
		"exist/my.config": "to be replaced",
		"untouched":       "existing top-level",
	}
	createContents(taskDir, preexisting, t)

	archive := "archive.tar.gz"
	artifact := &structs.TaskArtifact{
		GetterSource: fmt.Sprintf("%s/%s", ts.URL, archive),
		GetterOptions: map[string]string{
			"checksum": "sha1:20bab73c72c56490856f913cf594bad9a4d730f6",
		},
	}

	err := TestDefaultGetter(t).GetArtifact(noopTaskEnv(taskDir), artifact)
	if err != nil {
		t.Fatalf("GetArtifact failed: %v", err)
	}

	// Verify the unarchiving overrode files properly.
	want := map[string]string{
		"untouched":       "existing top-level",
		"exist/my.config": "hello world\n",
		"new/my.config":   "hello world\n",
		"test.sh":         "sleep 1\n",
	}
	checkContents(taskDir, want, t)
}
// TestGetArtifact_Setuid downloads setuid.tgz and checks the file modes of
// the unpacked "public", "private" and "setuid" files under the "setuid"
// directory. The expected mode for the "setuid" file carries no setuid bit.
func TestGetArtifact_Setuid(t *testing.T) {
	// Serve the fixture directory over HTTP.
	fixtures := http.Dir(filepath.Dir("./test-fixtures/"))
	ts := httptest.NewServer(http.FileServer(fixtures))
	defer ts.Close()

	// Download target.
	taskDir := t.TempDir()

	archive := "setuid.tgz"
	artifact := &structs.TaskArtifact{
		GetterSource: fmt.Sprintf("%s/%s", ts.URL, archive),
		GetterOptions: map[string]string{
			"checksum": "sha1:e892194748ecbad5d0f60c6c6b2db2bdaa384a90",
		},
	}

	err := TestDefaultGetter(t).GetArtifact(noopTaskEnv(taskDir), artifact)
	require.NoError(t, err)

	// Verify the unarchiving masked files properly.
	expected := map[string]int{
		"public":  0666,
		"private": 0600,
		"setuid":  0755,
	}
	if runtime.GOOS == "windows" {
		// windows doesn't support Chmod changing file permissions.
		expected = map[string]int{
			"public":  0666,
			"private": 0666,
			"setuid":  0666,
		}
	}

	for name, perm := range expected {
		info, err := os.Stat(filepath.Join(taskDir, "setuid", name))
		require.NoError(t, err)

		want, got := os.FileMode(perm), info.Mode()
		require.Equalf(t, want, got, "%s expected %o found %o", name, want, got)
	}
}
// TestGetArtifact_handlePanic tests that a panic during the getter execution
// does not cause its goroutine to crash.
func TestGetArtifact_handlePanic(t *testing.T) {
getter := TestDefaultGetter(t)
err := getter.GetArtifact(panicTaskEnv(), &structs.TaskArtifact{})
require.Error(t, err)
require.Contains(t, err.Error(), "panic")
}
// TestGetGetterUrl_Queries is a table-driven test of getGetterUrl. Each case
// supplies an artifact source plus optional getter options and the URL string
// that getGetterUrl is expected to return for it.
func TestGetGetterUrl_Queries(t *testing.T) {
	tests := []struct {
		name     string
		artifact *structs.TaskArtifact
		want     string
	}{
		{
			name: "adds query parameters",
			artifact: &structs.TaskArtifact{
				GetterSource: "https://foo.com?test=1",
				GetterOptions: map[string]string{
					"foo": "bar",
					"bam": "boom",
				},
			},
			want: "https://foo.com?bam=boom&foo=bar&test=1",
		},
		{
			name: "git without http",
			artifact: &structs.TaskArtifact{
				GetterSource: "github.com/hashicorp/nomad",
				GetterOptions: map[string]string{
					"ref": "abcd1234",
				},
			},
			want: "github.com/hashicorp/nomad?ref=abcd1234",
		},
		{
			name: "git using ssh",
			artifact: &structs.TaskArtifact{
				GetterSource: "git@github.com:hashicorp/nomad?sshkey=1",
				GetterOptions: map[string]string{
					"ref": "abcd1234",
				},
			},
			want: "git@github.com:hashicorp/nomad?ref=abcd1234&sshkey=1",
		},
		{
			name: "s3 scheme 1",
			artifact: &structs.TaskArtifact{
				GetterSource: "s3::https://s3.amazonaws.com/bucket/foo",
				GetterOptions: map[string]string{
					"aws_access_key_id": "abcd1234",
				},
			},
			want: "s3::https://s3.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
		},
		{
			name: "s3 scheme 2",
			artifact: &structs.TaskArtifact{
				GetterSource: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo",
				GetterOptions: map[string]string{
					"aws_access_key_id": "abcd1234",
				},
			},
			want: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
		},
		{
			name: "s3 scheme 3",
			artifact: &structs.TaskArtifact{
				GetterSource: "bucket.s3.amazonaws.com/foo",
				GetterOptions: map[string]string{
					"aws_access_key_id": "abcd1234",
				},
			},
			want: "bucket.s3.amazonaws.com/foo?aws_access_key_id=abcd1234",
		},
		{
			name: "s3 scheme 4",
			artifact: &structs.TaskArtifact{
				GetterSource: "bucket.s3-eu-west-1.amazonaws.com/foo/bar",
				GetterOptions: map[string]string{
					"aws_access_key_id": "abcd1234",
				},
			},
			want: "bucket.s3-eu-west-1.amazonaws.com/foo/bar?aws_access_key_id=abcd1234",
		},
		{
			name: "gcs",
			artifact: &structs.TaskArtifact{
				GetterSource: "gcs::https://www.googleapis.com/storage/v1/b/d/f",
			},
			want: "gcs::https://www.googleapis.com/storage/v1/b/d/f",
		},
		{
			name: "local file",
			artifact: &structs.TaskArtifact{
				GetterSource: "/foo/bar",
			},
			want: "/foo/bar",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := getGetterUrl(noopTaskEnv(""), tc.artifact)
			if err != nil {
				t.Fatalf("want %q; got err %v", tc.want, err)
			}
			if got != tc.want {
				t.Fatalf("want %q; got %q", tc.want, got)
			}
		})
	}
}