Add Concurrent Download Support for artifacts (#11531)
* add concurrent download support - resolves #11244 * format imports * mark `wg.Done()` via `defer` * added tests for successful and failure cases and resolved some goleak * docs: add changelog for #11531 * test typo fixes and improvements Co-authored-by: Michael Schurter <mschurter@hashicorp.com>
This commit is contained in:
parent
010acce59f
commit
1ff8b5f759
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
client: Download up to 3 artifacts concurrently
|
||||
```
|
|
@ -3,6 +3,7 @@ package taskrunner
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
||||
|
@ -25,6 +26,41 @@ func newArtifactHook(e ti.EventEmitter, logger log.Logger) *artifactHook {
|
|||
return h
|
||||
}
|
||||
|
||||
func (h *artifactHook) doWork(req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse, jobs chan *structs.TaskArtifact, errorChannel chan error, wg *sync.WaitGroup, responseStateMutex *sync.Mutex) {
|
||||
defer wg.Done()
|
||||
for artifact := range jobs {
|
||||
aid := artifact.Hash()
|
||||
if req.PreviousState[aid] != "" {
|
||||
h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource)
|
||||
responseStateMutex.Lock()
|
||||
resp.State[aid] = req.PreviousState[aid]
|
||||
responseStateMutex.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource, "aid", aid)
|
||||
//XXX add ctx to GetArtifact to allow cancelling long downloads
|
||||
if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil {
|
||||
|
||||
wrapped := structs.NewRecoverableError(
|
||||
fmt.Errorf("failed to download artifact %q: %v", artifact.GetterSource, err),
|
||||
true,
|
||||
)
|
||||
herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))
|
||||
|
||||
errorChannel <- herr
|
||||
continue
|
||||
}
|
||||
|
||||
// Mark artifact as downloaded to avoid re-downloading due to
|
||||
// retries caused by subsequent artifacts failing. Any
|
||||
// non-empty value works.
|
||||
responseStateMutex.Lock()
|
||||
resp.State[aid] = "1"
|
||||
responseStateMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (*artifactHook) Name() string {
|
||||
// Copied in client/state when upgrading from <0.9 schemas, so if you
|
||||
// change it here you also must change it there.
|
||||
|
@ -40,33 +76,48 @@ func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestar
|
|||
// Initialize hook state to store download progress
|
||||
resp.State = make(map[string]string, len(req.Task.Artifacts))
|
||||
|
||||
// responseStateMutex is a lock used to guard against concurrent writes to the above resp.State map
|
||||
responseStateMutex := &sync.Mutex{}
|
||||
|
||||
h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts))
|
||||
|
||||
for _, artifact := range req.Task.Artifacts {
|
||||
aid := artifact.Hash()
|
||||
if req.PreviousState[aid] != "" {
|
||||
h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource)
|
||||
resp.State[aid] = req.PreviousState[aid]
|
||||
continue
|
||||
// maxConcurrency denotes the number of workers that will download artifacts in parallel
|
||||
maxConcurrency := 3
|
||||
|
||||
// jobsChannel is a buffered channel which will have all the artifacts that needs to be processed
|
||||
jobsChannel := make(chan *structs.TaskArtifact, maxConcurrency)
|
||||
|
||||
// errorChannel is also a buffered channel that will be used to signal errors
|
||||
errorChannel := make(chan error, maxConcurrency)
|
||||
|
||||
// create workers and process artifacts
|
||||
go func() {
|
||||
defer close(errorChannel)
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < maxConcurrency; i++ {
|
||||
wg.Add(1)
|
||||
go h.doWork(req, resp, jobsChannel, errorChannel, &wg, responseStateMutex)
|
||||
}
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource)
|
||||
//XXX add ctx to GetArtifact to allow cancelling long downloads
|
||||
if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil {
|
||||
|
||||
wrapped := structs.NewRecoverableError(
|
||||
fmt.Errorf("failed to download artifact %q: %v", artifact.GetterSource, err),
|
||||
true,
|
||||
)
|
||||
herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))
|
||||
|
||||
return herr
|
||||
// Push all artifact requests to job channel
|
||||
go func() {
|
||||
defer close(jobsChannel)
|
||||
for _, artifact := range req.Task.Artifacts {
|
||||
jobsChannel <- artifact
|
||||
}
|
||||
}()
|
||||
|
||||
// Mark artifact as downloaded to avoid re-downloading due to
|
||||
// retries caused by subsequent artifacts failing. Any
|
||||
// non-empty value works.
|
||||
resp.State[aid] = "1"
|
||||
// Iterate over the errorChannel and if there is an error, store it to a variable for future return
|
||||
var err error
|
||||
for e := range errorChannel {
|
||||
err = e
|
||||
}
|
||||
|
||||
// once error channel is closed, we can check and return the error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp.Done = true
|
||||
|
|
|
@ -2,6 +2,7 @@ package taskrunner
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -73,11 +74,7 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
|
|||
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
|
||||
|
||||
// Create a source directory with 1 of the 2 artifacts
|
||||
srcdir, err := ioutil.TempDir("", "nomadtest-src")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, os.RemoveAll(srcdir))
|
||||
}()
|
||||
srcdir := t.TempDir()
|
||||
|
||||
// Only create one of the 2 artifacts to cause an error on first run.
|
||||
file1 := filepath.Join(srcdir, "foo.txt")
|
||||
|
@ -159,3 +156,212 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
|
|||
require.True(t, resp.Done)
|
||||
require.Len(t, resp.State, 2)
|
||||
}
|
||||
|
||||
// TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess asserts that the artifact hook
|
||||
// download multiple files concurrently. this is a successful test without any errors.
|
||||
func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
me := &mockEmitter{}
|
||||
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
|
||||
|
||||
// Create a source directory all 7 artifacts
|
||||
srcdir := t.TempDir()
|
||||
|
||||
numOfFiles := 7
|
||||
for i := 0; i < numOfFiles; i++ {
|
||||
file := filepath.Join(srcdir, fmt.Sprintf("file%d.txt", i))
|
||||
require.NoError(t, ioutil.WriteFile(file, []byte{byte(i)}, 0644))
|
||||
}
|
||||
|
||||
// Test server to serve the artifacts
|
||||
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
|
||||
defer ts.Close()
|
||||
|
||||
// Create the target directory.
|
||||
destdir, err := ioutil.TempDir("", "nomadtest-dest")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, os.RemoveAll(destdir))
|
||||
}()
|
||||
|
||||
req := &interfaces.TaskPrestartRequest{
|
||||
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
|
||||
TaskDir: &allocdir.TaskDir{Dir: destdir},
|
||||
Task: &structs.Task{
|
||||
Artifacts: []*structs.TaskArtifact{
|
||||
{
|
||||
GetterSource: ts.URL + "/file0.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file1.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file2.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file3.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file4.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file5.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file6.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := interfaces.TaskPrestartResponse{}
|
||||
|
||||
// start the hook
|
||||
err = artifactHook.Prestart(context.Background(), req, &resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.Done)
|
||||
require.Len(t, resp.State, 7)
|
||||
require.Len(t, me.events, 1)
|
||||
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)
|
||||
|
||||
// Assert all files downloaded properly
|
||||
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 7)
|
||||
sort.Strings(files)
|
||||
require.Contains(t, files[0], "file0.txt")
|
||||
require.Contains(t, files[1], "file1.txt")
|
||||
require.Contains(t, files[2], "file2.txt")
|
||||
require.Contains(t, files[3], "file3.txt")
|
||||
require.Contains(t, files[4], "file4.txt")
|
||||
require.Contains(t, files[5], "file5.txt")
|
||||
require.Contains(t, files[6], "file6.txt")
|
||||
}
|
||||
|
||||
// TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure asserts that the artifact hook
|
||||
// download multiple files concurrently. first iteration will result in failure and
|
||||
// second iteration should succeed without downloading already downloaded files.
|
||||
func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
me := &mockEmitter{}
|
||||
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
|
||||
|
||||
// Create a source directory with 3 of the 4 artifacts
|
||||
srcdir := t.TempDir()
|
||||
|
||||
file1 := filepath.Join(srcdir, "file1.txt")
|
||||
require.NoError(t, ioutil.WriteFile(file1, []byte{'1'}, 0644))
|
||||
|
||||
file2 := filepath.Join(srcdir, "file2.txt")
|
||||
require.NoError(t, ioutil.WriteFile(file2, []byte{'2'}, 0644))
|
||||
|
||||
file3 := filepath.Join(srcdir, "file3.txt")
|
||||
require.NoError(t, ioutil.WriteFile(file3, []byte{'3'}, 0644))
|
||||
|
||||
// Test server to serve the artifacts
|
||||
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
|
||||
defer ts.Close()
|
||||
|
||||
// Create the target directory.
|
||||
destdir, err := ioutil.TempDir("", "nomadtest-dest")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, os.RemoveAll(destdir))
|
||||
}()
|
||||
|
||||
req := &interfaces.TaskPrestartRequest{
|
||||
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
|
||||
TaskDir: &allocdir.TaskDir{Dir: destdir},
|
||||
Task: &structs.Task{
|
||||
Artifacts: []*structs.TaskArtifact{
|
||||
{
|
||||
GetterSource: ts.URL + "/file0.txt", // this request will fail
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file1.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file2.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
{
|
||||
GetterSource: ts.URL + "/file3.txt",
|
||||
GetterMode: structs.GetterModeAny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := interfaces.TaskPrestartResponse{}
|
||||
|
||||
// On first run all files will be downloaded except file0.txt
|
||||
err = artifactHook.Prestart(context.Background(), req, &resp)
|
||||
|
||||
require.Error(t, err)
|
||||
require.True(t, structs.IsRecoverable(err))
|
||||
require.Len(t, resp.State, 3)
|
||||
require.False(t, resp.Done)
|
||||
require.Len(t, me.events, 1)
|
||||
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)
|
||||
|
||||
// delete the downloaded files so that it'll error if it's downloaded again
|
||||
require.NoError(t, os.Remove(file1))
|
||||
require.NoError(t, os.Remove(file2))
|
||||
require.NoError(t, os.Remove(file3))
|
||||
|
||||
// create the missing file
|
||||
file0 := filepath.Join(srcdir, "file0.txt")
|
||||
require.NoError(t, ioutil.WriteFile(file0, []byte{'0'}, 0644))
|
||||
|
||||
// Mock TaskRunner by copying state from resp to req and reset resp.
|
||||
req.PreviousState = helper.CopyMapStringString(resp.State)
|
||||
|
||||
resp = interfaces.TaskPrestartResponse{}
|
||||
|
||||
// Retry the download and assert it succeeds
|
||||
err = artifactHook.Prestart(context.Background(), req, &resp)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.Done)
|
||||
require.Len(t, resp.State, 4)
|
||||
|
||||
// Assert all files downloaded properly
|
||||
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
|
||||
require.NoError(t, err)
|
||||
sort.Strings(files)
|
||||
require.Contains(t, files[0], "file0.txt")
|
||||
require.Contains(t, files[1], "file1.txt")
|
||||
require.Contains(t, files[2], "file2.txt")
|
||||
require.Contains(t, files[3], "file3.txt")
|
||||
|
||||
// verify the file contents too, since files will also be created for failed downloads
|
||||
data0, err := ioutil.ReadFile(files[0])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data0, []byte{'0'})
|
||||
|
||||
data1, err := ioutil.ReadFile(files[1])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data1, []byte{'1'})
|
||||
|
||||
data2, err := ioutil.ReadFile(files[2])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data2, []byte{'2'})
|
||||
|
||||
data3, err := ioutil.ReadFile(files[3])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data3, []byte{'3'})
|
||||
|
||||
require.True(t, resp.Done)
|
||||
require.Len(t, resp.State, 4)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue