67 lines
1.9 KiB
Go
67 lines
1.9 KiB
Go
|
package taskrunner
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
|
||
|
log "github.com/hashicorp/go-hclog"
|
||
|
multierror "github.com/hashicorp/go-multierror"
|
||
|
"github.com/hashicorp/nomad/client/allocrunnerv2/interfaces"
|
||
|
"github.com/hashicorp/nomad/client/config"
|
||
|
"github.com/hashicorp/nomad/client/driver/env"
|
||
|
"github.com/hashicorp/nomad/nomad/structs"
|
||
|
)
|
||
|
|
||
|
// validateHook validates the task is able to be run.
|
||
|
type validateHook struct {
|
||
|
config *config.Config
|
||
|
logger log.Logger
|
||
|
}
|
||
|
|
||
|
func newValidateHook(config *config.Config, logger log.Logger) *validateHook {
|
||
|
h := &validateHook{
|
||
|
config: config,
|
||
|
}
|
||
|
h.logger = logger.Named(h.Name())
|
||
|
return h
|
||
|
}
|
||
|
|
||
|
func (*validateHook) Name() string {
|
||
|
return "validate"
|
||
|
}
|
||
|
|
||
|
func (h *validateHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
|
||
|
if err := validateTask(req.Task, req.TaskEnv, h.config); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
resp.Done = true
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func validateTask(task *structs.Task, taskEnv *env.TaskEnv, conf *config.Config) error {
|
||
|
var mErr multierror.Error
|
||
|
|
||
|
// Validate the user
|
||
|
unallowedUsers := conf.ReadStringListToMapDefault("user.blacklist", config.DefaultUserBlacklist)
|
||
|
checkDrivers := conf.ReadStringListToMapDefault("user.checked_drivers", config.DefaultUserCheckedDrivers)
|
||
|
if _, driverMatch := checkDrivers[task.Driver]; driverMatch {
|
||
|
if _, unallowed := unallowedUsers[task.User]; unallowed {
|
||
|
mErr.Errors = append(mErr.Errors, fmt.Errorf("running as user %q is disallowed", task.User))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Validate the Service names once they're interpolated
|
||
|
for i, service := range task.Services {
|
||
|
name := taskEnv.ReplaceEnv(service.Name)
|
||
|
if err := service.ValidateName(name); err != nil {
|
||
|
mErr.Errors = append(mErr.Errors, fmt.Errorf("service (%d) failed validation: %v", i, err))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(mErr.Errors) == 1 {
|
||
|
return mErr.Errors[0]
|
||
|
}
|
||
|
return mErr.ErrorOrNil()
|
||
|
}
|