diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 68bfb0335..4282453ea 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -152,16 +152,19 @@ type TaskConfig struct { StartErrRecoverable bool `codec:"start_error_recoverable"` // StartBlockFor specifies a duration in which to block before returning - StartBlockFor string `codec:"start_block_for"` + StartBlockFor string `codec:"start_block_for"` + startBlockForDuration time.Duration // KillAfter is the duration after which the mock driver indicates the task // has exited after getting the initial SIGINT signal - KillAfter string `codec:"kill_after"` + KillAfter string `codec:"kill_after"` + killAfterDuration time.Duration // RunFor is the duration for which the fake task runs for. After this // period the MockDriver responds to the task running indicating that the // task has terminated - RunFor string `codec:"run_for"` + RunFor string `codec:"run_for"` + runForDuration time.Duration // ExitCode is the exit code with which the MockDriver indicates the task // has exited @@ -195,7 +198,8 @@ type TaskConfig struct { StdoutRepeat int `codec:"stdout_repeat"` // StdoutRepeatDur is the duration between repeated outputs. - StdoutRepeatDur string `codec:"stdout_repeat_duration"` + StdoutRepeatDur string `codec:"stdout_repeat_duration"` + stdoutRepeatDuration time.Duration } type MockTaskState struct { @@ -298,8 +302,21 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru return nil, nil, err } - if driverConfig.StartBlockFor != "" { - time.Sleep(parseDuration(driverConfig.StartBlockFor)) + var err error + if driverConfig.startBlockForDuration, err = parseDuration(driverConfig.StartBlockFor); err != nil { + return nil, nil, fmt.Errorf("start_block_for %v not a valid duration: %v", driverConfig.StartBlockFor, err) + } + + if driverConfig.runForDuration, err = parseDuration(driverConfig.RunFor); err != nil { + return nil, nil, fmt.Errorf("run_for %v not a valid duration: %v", driverConfig.RunFor, err) + } + + if driverConfig.stdoutRepeatDuration, err = parseDuration(driverConfig.StdoutRepeatDur); err != nil { + return nil, nil, fmt.Errorf("stdout_repeat_duration %v not a valid duration: %v", driverConfig.stdoutRepeatDuration, err) + } + + if driverConfig.startBlockForDuration != 0 { + time.Sleep(driverConfig.startBlockForDuration) } if driverConfig.StartErr != "" { @@ -327,13 +344,13 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru h := &taskHandle{ taskConfig: cfg, - runFor: parseDuration(driverConfig.RunFor), - killAfter: parseDuration(driverConfig.KillAfter), + runFor: driverConfig.runForDuration, + killAfter: driverConfig.killAfterDuration, exitCode: driverConfig.ExitCode, exitSignal: driverConfig.ExitSignal, stdoutString: driverConfig.StdoutString, stdoutRepeat: driverConfig.StdoutRepeat, - stdoutRepeatDur: parseDuration(driverConfig.StdoutRepeatDur), + stdoutRepeatDur: driverConfig.stdoutRepeatDuration, logger: d.logger.With("task_name", cfg.Name), waitCh: make(chan struct{}), killCh: killCtx.Done(), @@ -455,21 +472,3 @@ func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (* } return &res, nil } - -func parseDuration(s string) time.Duration { - if s == "" { - return time.Duration(0) - } - - // check if it's an int64 - if v, err := strconv.ParseInt(s, 10, 64); err == nil { - return time.Duration(v) - } - - // try to parse it as duration - if v, err := time.ParseDuration(s); err == nil { - return v - } - - panic(fmt.Errorf("value is not a duration: %v", s)) -} diff --git a/drivers/mock/utils.go b/drivers/mock/utils.go new file mode 100644 index 000000000..99ea2dbd4 --- /dev/null +++ b/drivers/mock/utils.go @@ -0,0 +1,16 @@ +package mock + +import ( + "time" +) + +// parseDuration parses a duration string, like time.ParseDuration +// but is empty string friendly, returns a zero time duration +func parseDuration(s string) (time.Duration, error) { + if s == "" { + return time.Duration(0), nil + } + + // try to parse it as duration + return time.ParseDuration(s) +} diff --git a/drivers/mock/utils_test.go b/drivers/mock/utils_test.go new file mode 100644 index 000000000..42458c978 --- /dev/null +++ b/drivers/mock/utils_test.go @@ -0,0 +1,29 @@ +package mock + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestParseDuration(t *testing.T) { + t.Run("valid case", func(t *testing.T) { + v, err := parseDuration("10m") + require.NoError(t, err) + require.Equal(t, 10*time.Minute, v) + }) + + t.Run("invalid case", func(t *testing.T) { + v, err := parseDuration("10") + require.Error(t, err) + require.Equal(t, time.Duration(0), v) + }) + + t.Run("empty case", func(t *testing.T) { + v, err := parseDuration("") + require.NoError(t, err) + require.Equal(t, time.Duration(0), v) + }) + +}