interpet the artifact source

This commit is contained in:
Alex Dadgar 2016-04-11 18:46:16 -07:00
parent 98bbb10217
commit dc63c24e59
4 changed files with 65 additions and 23 deletions

View file

@ -8,6 +8,7 @@ import (
"sync"
gg "github.com/hashicorp/go-getter"
"github.com/hashicorp/nomad/client/driver/env"
"github.com/hashicorp/nomad/nomad/structs"
)
@ -45,8 +46,9 @@ func getClient(src, dst string) *gg.Client {
}
// getGetterUrl returns the go-getter URL to download the artifact.
func getGetterUrl(artifact *structs.TaskArtifact) (string, error) {
u, err := url.Parse(artifact.GetterSource)
func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact) (string, error) {
taskEnv.Build()
u, err := url.Parse(taskEnv.ReplaceEnv(artifact.GetterSource))
if err != nil {
return "", fmt.Errorf("failed to parse source URL %q: %v", artifact.GetterSource, err)
}
@ -54,15 +56,17 @@ func getGetterUrl(artifact *structs.TaskArtifact) (string, error) {
// Build the url
q := u.Query()
for k, v := range artifact.GetterOptions {
q.Add(k, v)
q.Add(k, taskEnv.ReplaceEnv(v))
}
u.RawQuery = q.Encode()
return u.String(), nil
}
// GetArtifact downloads an artifact into the specified task directory.
func GetArtifact(artifact *structs.TaskArtifact, taskDir string, logger *log.Logger) error {
url, err := getGetterUrl(artifact)
func GetArtifact(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact,
taskDir string, logger *log.Logger) error {
url, err := getGetterUrl(taskEnv, artifact)
if err != nil {
return err
}

View file

@ -12,6 +12,8 @@ import (
"strings"
"testing"
"github.com/hashicorp/nomad/client/driver/env"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
)
@ -37,8 +39,9 @@ func TestGetArtifact_FileAndChecksum(t *testing.T) {
}
// Download the artifact
taskEnv := env.NewTaskEnvironment(mock.Node())
logger := log.New(os.Stderr, "", log.LstdFlags)
if err := GetArtifact(artifact, taskDir, logger); err != nil {
if err := GetArtifact(taskEnv, artifact, taskDir, logger); err != nil {
t.Fatalf("GetArtifact failed: %v", err)
}
@ -72,8 +75,9 @@ func TestGetArtifact_File_RelativeDest(t *testing.T) {
}
// Download the artifact
taskEnv := env.NewTaskEnvironment(mock.Node())
logger := log.New(os.Stderr, "", log.LstdFlags)
if err := GetArtifact(artifact, taskDir, logger); err != nil {
if err := GetArtifact(taskEnv, artifact, taskDir, logger); err != nil {
t.Fatalf("GetArtifact failed: %v", err)
}
@ -83,6 +87,24 @@ func TestGetArtifact_File_RelativeDest(t *testing.T) {
}
}
func TestGetGetterUrl_Interprolation(t *testing.T) {
// Create the artifact
artifact := &structs.TaskArtifact{
GetterSource: "${NOMAD_META_ARTIFACT}",
}
url := "foo.com"
taskEnv := env.NewTaskEnvironment(mock.Node()).SetTaskMeta(map[string]string{"artifact": url})
act, err := getGetterUrl(taskEnv, artifact)
if err != nil {
t.Fatalf("getGetterUrl() failed: %v", err)
}
if act != url {
t.Fatalf("getGetterUrl() returned %q; want %q", act, url)
}
}
func TestGetArtifact_InvalidChecksum(t *testing.T) {
// Create the test server hosting the file to download
ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/"))))
@ -105,8 +127,9 @@ func TestGetArtifact_InvalidChecksum(t *testing.T) {
}
// Download the artifact and expect an error
taskEnv := env.NewTaskEnvironment(mock.Node())
logger := log.New(os.Stderr, "", log.LstdFlags)
if err := GetArtifact(artifact, taskDir, logger); err == nil {
if err := GetArtifact(taskEnv, artifact, taskDir, logger); err == nil {
t.Fatalf("GetArtifact should have failed")
}
}
@ -171,8 +194,9 @@ func TestGetArtifact_Archive(t *testing.T) {
},
}
taskEnv := env.NewTaskEnvironment(mock.Node())
logger := log.New(os.Stderr, "", log.LstdFlags)
if err := GetArtifact(artifact, taskDir, logger); err != nil {
if err := GetArtifact(taskEnv, artifact, taskDir, logger); err != nil {
t.Fatalf("GetArtifact failed: %v", err)
}

View file

@ -16,6 +16,7 @@ import (
"github.com/hashicorp/nomad/client/getter"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/client/driver/env"
cstructs "github.com/hashicorp/nomad/client/driver/structs"
)
@ -43,6 +44,7 @@ type TaskRunner struct {
restartTracker *RestartTracker
task *structs.Task
taskEnv *env.TaskEnvironment
updateCh chan *structs.Allocation
handle driver.DriverHandle
handleLock sync.Mutex
@ -188,18 +190,29 @@ func (r *TaskRunner) setState(state string, event *structs.TaskEvent) {
r.updater(r.task.Name, state, event)
}
// createDriver makes a driver for the task
func (r *TaskRunner) createDriver() (driver.Driver, error) {
// setTaskEnv sets the task environment. It returns an error if it could not be
// created.
func (r *TaskRunner) setTaskEnv() error {
taskEnv, err := driver.GetTaskEnv(r.ctx.AllocDir, r.config.Node, r.task, r.alloc)
if err != nil {
err = fmt.Errorf("failed to create driver '%s' for alloc %s: %v",
return err
}
r.taskEnv = taskEnv
return nil
}
// createDriver makes a driver for the task
func (r *TaskRunner) createDriver() (driver.Driver, error) {
if r.taskEnv == nil {
if err := r.setTaskEnv(); err != nil {
err := fmt.Errorf("failed to create driver '%s' for alloc %s: %v",
r.task.Driver, r.alloc.ID, err)
r.logger.Printf("[ERR] client: %s", err)
return nil, err
}
}
driverCtx := driver.NewDriverContext(r.task.Name, r.config, r.config.Node, r.logger, taskEnv)
driverCtx := driver.NewDriverContext(r.task.Name, r.config, r.config.Node, r.logger, r.taskEnv)
driver, err := driver.NewDriver(r.task.Driver, driverCtx)
if err != nil {
err = fmt.Errorf("failed to create driver '%s' for alloc %s: %v",
@ -223,6 +236,13 @@ func (r *TaskRunner) Run() {
return
}
if err := r.setTaskEnv(); err != nil {
r.setState(
structs.TaskStateDead,
structs.NewTaskEvent(structs.TaskDriverFailure).SetDriverError(err))
return
}
r.run()
return
}
@ -277,7 +297,7 @@ func (r *TaskRunner) run() {
}
for _, artifact := range r.task.Artifacts {
if err := getter.GetArtifact(artifact, taskDir, r.logger); err != nil {
if err := getter.GetArtifact(r.taskEnv, artifact, taskDir, r.logger); err != nil {
r.setState(structs.TaskStateDead,
structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(err))
r.restartTracker.SetStartError(cstructs.NewRecoverableError(err, true))

View file

@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"io"
"net/url"
"path/filepath"
"reflect"
"regexp"
@ -1980,11 +1979,6 @@ func (ta *TaskArtifact) Validate() error {
var mErr multierror.Error
if ta.GetterSource == "" {
mErr.Errors = append(mErr.Errors, fmt.Errorf("source must be specified"))
} else {
_, err := url.Parse(ta.GetterSource)
if err != nil {
mErr.Errors = append(mErr.Errors, fmt.Errorf("invalid source URL %q: %v", ta.GetterSource, err))
}
}
// Verify the destination doesn't escape the tasks directory