Use materialized duration fields for driver config

This commit is contained in:
Mahmood Ali 2018-11-12 22:15:57 -05:00
parent 865419e756
commit 356c194acc
3 changed files with 71 additions and 27 deletions

View File

@ -152,16 +152,19 @@ type TaskConfig struct {
StartErrRecoverable bool `codec:"start_error_recoverable"`
// StartBlockFor specifies a duration in which to block before returning
StartBlockFor string `codec:"start_block_for"`
StartBlockFor string `codec:"start_block_for"`
startBlockForDuration time.Duration
// KillAfter is the duration after which the mock driver indicates the task
// has exited after getting the initial SIGINT signal
KillAfter string `codec:"kill_after"`
KillAfter string `codec:"kill_after"`
killAfterDuration time.Duration
// RunFor is the duration for which the fake task runs for. After this
// period the MockDriver responds to the task running indicating that the
// task has terminated
RunFor string `codec:"run_for"`
RunFor string `codec:"run_for"`
runForDuration time.Duration
// ExitCode is the exit code with which the MockDriver indicates the task
// has exited
@ -195,7 +198,8 @@ type TaskConfig struct {
StdoutRepeat int `codec:"stdout_repeat"`
// StdoutRepeatDur is the duration between repeated outputs.
StdoutRepeatDur string `codec:"stdout_repeat_duration"`
StdoutRepeatDur string `codec:"stdout_repeat_duration"`
stdoutRepeatDuration time.Duration
}
type MockTaskState struct {
@ -298,8 +302,21 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
return nil, nil, err
}
if driverConfig.StartBlockFor != "" {
time.Sleep(parseDuration(driverConfig.StartBlockFor))
var err error
if driverConfig.startBlockForDuration, err = parseDuration(driverConfig.StartBlockFor); err != nil {
return nil, nil, fmt.Errorf("start_block_for %v not a valid duration: %v", driverConfig.StartBlockFor, err)
}
if driverConfig.runForDuration, err = parseDuration(driverConfig.RunFor); err != nil {
return nil, nil, fmt.Errorf("run_for %v not a valid duration: %v", driverConfig.RunFor, err)
}
if driverConfig.stdoutRepeatDuration, err = parseDuration(driverConfig.StdoutRepeatDur); err != nil {
return nil, nil, fmt.Errorf("stdout_repeat_duration %v not a valid duration: %v", driverConfig.stdoutRepeatDuration, err)
}
if driverConfig.startBlockForDuration != 0 {
time.Sleep(driverConfig.startBlockForDuration)
}
if driverConfig.StartErr != "" {
@ -327,13 +344,13 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
h := &taskHandle{
taskConfig: cfg,
runFor: parseDuration(driverConfig.RunFor),
killAfter: parseDuration(driverConfig.KillAfter),
runFor: driverConfig.runForDuration,
killAfter: driverConfig.killAfterDuration,
exitCode: driverConfig.ExitCode,
exitSignal: driverConfig.ExitSignal,
stdoutString: driverConfig.StdoutString,
stdoutRepeat: driverConfig.StdoutRepeat,
stdoutRepeatDur: parseDuration(driverConfig.StdoutRepeatDur),
stdoutRepeatDur: driverConfig.stdoutRepeatDuration,
logger: d.logger.With("task_name", cfg.Name),
waitCh: make(chan struct{}),
killCh: killCtx.Done(),
@ -455,21 +472,3 @@ func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (*
}
return &res, nil
}
func parseDuration(s string) time.Duration {
if s == "" {
return time.Duration(0)
}
// check if it's an int64
if v, err := strconv.ParseInt(s, 10, 64); err == nil {
return time.Duration(v)
}
// try to parse it as duration
if v, err := time.ParseDuration(s); err == nil {
return v
}
panic(fmt.Errorf("value is not a duration: %v", s))
}

16
drivers/mock/utils.go Normal file
View File

@ -0,0 +1,16 @@
package mock
import (
"time"
)
// parseDuration parses a duration string, like time.ParseDuration
// but is empty string friendly, returns a zero time duration
func parseDuration(s string) (time.Duration, error) {
if s == "" {
return time.Duration(0), nil
}
// try to parse it as duration
return time.ParseDuration(s)
}

View File

@ -0,0 +1,29 @@
package mock
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestParseDuration(t *testing.T) {
t.Run("valid case", func(t *testing.T) {
v, err := parseDuration("10m")
require.NoError(t, err)
require.Equal(t, 10*time.Minute, v)
})
t.Run("invalid case", func(t *testing.T) {
v, err := parseDuration("10")
require.Error(t, err)
require.Equal(t, time.Duration(0), v)
})
t.Run("empty case", func(t *testing.T) {
v, err := parseDuration("")
require.NoError(t, err)
require.Equal(t, time.Duration(0), v)
})
}