Add Concurrent Download Support for artifacts (#11531)

* add concurrent download support - resolves #11244

* format imports

* mark `wg.Done()` via `defer`

* added tests for successful and failure cases and resolved some goleak

* docs: add changelog for #11531

* test typo fixes and improvements

Co-authored-by: Michael Schurter <mschurter@hashicorp.com>
This commit is contained in:
Gowtham 2022-04-20 22:45:56 +05:30 committed by GitHub
parent 010acce59f
commit 1ff8b5f759
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 286 additions and 26 deletions

3
.changelog/11531.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
client: Download up to 3 artifacts concurrently
```

View File

@ -3,6 +3,7 @@ package taskrunner
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
log "github.com/hashicorp/go-hclog" log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocrunner/interfaces" "github.com/hashicorp/nomad/client/allocrunner/interfaces"
@ -25,6 +26,41 @@ func newArtifactHook(e ti.EventEmitter, logger log.Logger) *artifactHook {
return h return h
} }
func (h *artifactHook) doWork(req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse, jobs chan *structs.TaskArtifact, errorChannel chan error, wg *sync.WaitGroup, responseStateMutex *sync.Mutex) {
defer wg.Done()
for artifact := range jobs {
aid := artifact.Hash()
if req.PreviousState[aid] != "" {
h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource)
responseStateMutex.Lock()
resp.State[aid] = req.PreviousState[aid]
responseStateMutex.Unlock()
continue
}
h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource, "aid", aid)
//XXX add ctx to GetArtifact to allow cancelling long downloads
if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil {
wrapped := structs.NewRecoverableError(
fmt.Errorf("failed to download artifact %q: %v", artifact.GetterSource, err),
true,
)
herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))
errorChannel <- herr
continue
}
// Mark artifact as downloaded to avoid re-downloading due to
// retries caused by subsequent artifacts failing. Any
// non-empty value works.
responseStateMutex.Lock()
resp.State[aid] = "1"
responseStateMutex.Unlock()
}
}
func (*artifactHook) Name() string { func (*artifactHook) Name() string {
// Copied in client/state when upgrading from <0.9 schemas, so if you // Copied in client/state when upgrading from <0.9 schemas, so if you
// change it here you also must change it there. // change it here you also must change it there.
@ -40,33 +76,48 @@ func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestar
// Initialize hook state to store download progress // Initialize hook state to store download progress
resp.State = make(map[string]string, len(req.Task.Artifacts)) resp.State = make(map[string]string, len(req.Task.Artifacts))
// responseStateMutex is a lock used to guard against concurrent writes to the above resp.State map
responseStateMutex := &sync.Mutex{}
h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts)) h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts))
// maxConcurrency denotes the number of workers that will download artifacts in parallel
maxConcurrency := 3
// jobsChannel is a buffered channel which will have all the artifacts that needs to be processed
jobsChannel := make(chan *structs.TaskArtifact, maxConcurrency)
// errorChannel is also a buffered channel that will be used to signal errors
errorChannel := make(chan error, maxConcurrency)
// create workers and process artifacts
go func() {
defer close(errorChannel)
var wg sync.WaitGroup
for i := 0; i < maxConcurrency; i++ {
wg.Add(1)
go h.doWork(req, resp, jobsChannel, errorChannel, &wg, responseStateMutex)
}
wg.Wait()
}()
// Push all artifact requests to job channel
go func() {
defer close(jobsChannel)
for _, artifact := range req.Task.Artifacts { for _, artifact := range req.Task.Artifacts {
aid := artifact.Hash() jobsChannel <- artifact
if req.PreviousState[aid] != "" { }
h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource) }()
resp.State[aid] = req.PreviousState[aid]
continue // Iterate over the errorChannel and if there is an error, store it to a variable for future return
var err error
for e := range errorChannel {
err = e
} }
h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource) // once error channel is closed, we can check and return the error
//XXX add ctx to GetArtifact to allow cancelling long downloads if err != nil {
if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil { return err
wrapped := structs.NewRecoverableError(
fmt.Errorf("failed to download artifact %q: %v", artifact.GetterSource, err),
true,
)
herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))
return herr
}
// Mark artifact as downloaded to avoid re-downloading due to
// retries caused by subsequent artifacts failing. Any
// non-empty value works.
resp.State[aid] = "1"
} }
resp.Done = true resp.Done = true

View File

@ -2,6 +2,7 @@ package taskrunner
import ( import (
"context" "context"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -73,11 +74,7 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
artifactHook := newArtifactHook(me, testlog.HCLogger(t)) artifactHook := newArtifactHook(me, testlog.HCLogger(t))
// Create a source directory with 1 of the 2 artifacts // Create a source directory with 1 of the 2 artifacts
srcdir, err := ioutil.TempDir("", "nomadtest-src") srcdir := t.TempDir()
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(srcdir))
}()
// Only create one of the 2 artifacts to cause an error on first run. // Only create one of the 2 artifacts to cause an error on first run.
file1 := filepath.Join(srcdir, "foo.txt") file1 := filepath.Join(srcdir, "foo.txt")
@ -159,3 +156,212 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
require.True(t, resp.Done) require.True(t, resp.Done)
require.Len(t, resp.State, 2) require.Len(t, resp.State, 2)
} }
// TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess asserts that the artifact hook
// download multiple files concurrently. this is a successful test without any errors.
func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) {
t.Parallel()
me := &mockEmitter{}
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
// Create a source directory all 7 artifacts
srcdir := t.TempDir()
numOfFiles := 7
for i := 0; i < numOfFiles; i++ {
file := filepath.Join(srcdir, fmt.Sprintf("file%d.txt", i))
require.NoError(t, ioutil.WriteFile(file, []byte{byte(i)}, 0644))
}
// Test server to serve the artifacts
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
defer ts.Close()
// Create the target directory.
destdir, err := ioutil.TempDir("", "nomadtest-dest")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(destdir))
}()
req := &interfaces.TaskPrestartRequest{
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
TaskDir: &allocdir.TaskDir{Dir: destdir},
Task: &structs.Task{
Artifacts: []*structs.TaskArtifact{
{
GetterSource: ts.URL + "/file0.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file1.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file2.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file3.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file4.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file5.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file6.txt",
GetterMode: structs.GetterModeAny,
},
},
},
}
resp := interfaces.TaskPrestartResponse{}
// start the hook
err = artifactHook.Prestart(context.Background(), req, &resp)
require.NoError(t, err)
require.True(t, resp.Done)
require.Len(t, resp.State, 7)
require.Len(t, me.events, 1)
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)
// Assert all files downloaded properly
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
require.NoError(t, err)
require.Len(t, files, 7)
sort.Strings(files)
require.Contains(t, files[0], "file0.txt")
require.Contains(t, files[1], "file1.txt")
require.Contains(t, files[2], "file2.txt")
require.Contains(t, files[3], "file3.txt")
require.Contains(t, files[4], "file4.txt")
require.Contains(t, files[5], "file5.txt")
require.Contains(t, files[6], "file6.txt")
}
// TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure asserts that the artifact hook
// download multiple files concurrently. first iteration will result in failure and
// second iteration should succeed without downloading already downloaded files.
func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) {
t.Parallel()
me := &mockEmitter{}
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
// Create a source directory with 3 of the 4 artifacts
srcdir := t.TempDir()
file1 := filepath.Join(srcdir, "file1.txt")
require.NoError(t, ioutil.WriteFile(file1, []byte{'1'}, 0644))
file2 := filepath.Join(srcdir, "file2.txt")
require.NoError(t, ioutil.WriteFile(file2, []byte{'2'}, 0644))
file3 := filepath.Join(srcdir, "file3.txt")
require.NoError(t, ioutil.WriteFile(file3, []byte{'3'}, 0644))
// Test server to serve the artifacts
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
defer ts.Close()
// Create the target directory.
destdir, err := ioutil.TempDir("", "nomadtest-dest")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(destdir))
}()
req := &interfaces.TaskPrestartRequest{
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
TaskDir: &allocdir.TaskDir{Dir: destdir},
Task: &structs.Task{
Artifacts: []*structs.TaskArtifact{
{
GetterSource: ts.URL + "/file0.txt", // this request will fail
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file1.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file2.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file3.txt",
GetterMode: structs.GetterModeAny,
},
},
},
}
resp := interfaces.TaskPrestartResponse{}
// On first run all files will be downloaded except file0.txt
err = artifactHook.Prestart(context.Background(), req, &resp)
require.Error(t, err)
require.True(t, structs.IsRecoverable(err))
require.Len(t, resp.State, 3)
require.False(t, resp.Done)
require.Len(t, me.events, 1)
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)
// delete the downloaded files so that it'll error if it's downloaded again
require.NoError(t, os.Remove(file1))
require.NoError(t, os.Remove(file2))
require.NoError(t, os.Remove(file3))
// create the missing file
file0 := filepath.Join(srcdir, "file0.txt")
require.NoError(t, ioutil.WriteFile(file0, []byte{'0'}, 0644))
// Mock TaskRunner by copying state from resp to req and reset resp.
req.PreviousState = helper.CopyMapStringString(resp.State)
resp = interfaces.TaskPrestartResponse{}
// Retry the download and assert it succeeds
err = artifactHook.Prestart(context.Background(), req, &resp)
require.NoError(t, err)
require.True(t, resp.Done)
require.Len(t, resp.State, 4)
// Assert all files downloaded properly
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
require.NoError(t, err)
sort.Strings(files)
require.Contains(t, files[0], "file0.txt")
require.Contains(t, files[1], "file1.txt")
require.Contains(t, files[2], "file2.txt")
require.Contains(t, files[3], "file3.txt")
// verify the file contents too, since files will also be created for failed downloads
data0, err := ioutil.ReadFile(files[0])
require.NoError(t, err)
require.Equal(t, data0, []byte{'0'})
data1, err := ioutil.ReadFile(files[1])
require.NoError(t, err)
require.Equal(t, data1, []byte{'1'})
data2, err := ioutil.ReadFile(files[2])
require.NoError(t, err)
require.Equal(t, data2, []byte{'2'})
data3, err := ioutil.ReadFile(files[3])
require.NoError(t, err)
require.Equal(t, data3, []byte{'3'})
require.True(t, resp.Done)
require.Len(t, resp.State, 4)
}