diff --git a/command/services/register/config.go b/command/services/register/config.go index 0ab9df248..680e3f13d 100644 --- a/command/services/register/config.go +++ b/command/services/register/config.go @@ -1,16 +1,64 @@ package register import ( - "github.com/hashicorp/consul/agent/config" + "reflect" + "time" + + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/mitchellh/mapstructure" ) -// configToAgentService converts a ServiceDefinition struct to an +// serviceToAgentService converts a ServiceDefinition struct to an // AgentServiceRegistration API struct. -func configToAgentService(svc *config.ServiceDefinition) (*api.AgentServiceRegistration, error) { +func serviceToAgentService(svc *structs.ServiceDefinition) (*api.AgentServiceRegistration, error) { // mapstructure can do this for us, but we encapsulate it in this // helper function in case we need to change the logic in the future. var result api.AgentServiceRegistration - return &result, mapstructure.Decode(svc, &result) + d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: &result, + DecodeHook: timeDurationToStringHookFunc(), + WeaklyTypedInput: true, + }) + if err != nil { + return nil, err + } + if err := d.Decode(svc); err != nil { + return nil, err + } + + // The structs version has non-pointer checks and the destination + // has pointers, so we need to set the destination to nil if there + // is no check ID set. + if result.Check != nil && result.Check.Name == "" { + result.Check = nil + } + if len(result.Checks) == 1 && result.Checks[0].Name == "" { + result.Checks = nil + } + + return &result, nil +} + +// timeDurationToStringHookFunc returns a DecodeHookFunc that converts +// time.Duration to string. +func timeDurationToStringHookFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + dur, ok := data.(time.Duration) + if !ok { + return data, nil + } + if t.Kind() != reflect.String { + return data, nil + } + if dur == 0 { + return "", nil + } + + // Convert it by parsing + return data.(time.Duration).String(), nil + } } diff --git a/command/services/register/config_test.go b/command/services/register/config_test.go index dc4ed3a71..767144912 100644 --- a/command/services/register/config_test.go +++ b/command/services/register/config_test.go @@ -3,23 +3,23 @@ package register import ( "testing" - "github.com/hashicorp/consul/agent/config" + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/stretchr/testify/require" ) -func TestConfigToAgentService(t *testing.T) { +func TestStructsToAgentService(t *testing.T) { cases := []struct { Name string - Input *config.ServiceDefinition + Input *structs.ServiceDefinition Output *api.AgentServiceRegistration }{ { "Basic service with port", - &config.ServiceDefinition{ - Name: strPtr("web"), + &structs.ServiceDefinition{ + Name: "web", Tags: []string{"leader"}, - Port: intPtr(1234), + Port: 1234, }, &api.AgentServiceRegistration{ Name: "web", @@ -29,10 +29,10 @@ func TestConfigToAgentService(t *testing.T) { }, { "Service with a check", - &config.ServiceDefinition{ - Name: strPtr("web"), - Check: &config.CheckDefinition{ - Name: strPtr("ping"), + &structs.ServiceDefinition{ + Name: "web", + Check: structs.CheckType{ + Name: "ping", }, }, &api.AgentServiceRegistration{ @@ -44,14 +44,14 @@ func TestConfigToAgentService(t *testing.T) { }, { "Service with checks", - &config.ServiceDefinition{ - Name: strPtr("web"), - Checks: []config.CheckDefinition{ - config.CheckDefinition{ - Name: strPtr("ping"), + &structs.ServiceDefinition{ + Name: "web", + Checks: structs.CheckTypes{ + &structs.CheckType{ + Name: "ping", }, - config.CheckDefinition{ - Name: strPtr("pong"), + &structs.CheckType{ + Name: "pong", }, }, }, @@ -72,7 +72,7 @@ func TestConfigToAgentService(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { require := require.New(t) - actual, err := configToAgentService(tc.Input) + actual, err := serviceToAgentService(tc.Input) require.NoError(err) require.Equal(tc.Output, actual) }) diff --git a/command/services/register/register.go b/command/services/register/register.go index 0ceb24c40..1a754a640 100644 --- a/command/services/register/register.go +++ b/command/services/register/register.go @@ -2,8 +2,10 @@ package register import ( "flag" + "fmt" - //"github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/agent/config" + "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/command/flags" "github.com/mitchellh/cli" ) @@ -48,25 +50,67 @@ func (c *cmd) Run(args []string) int { return 1 } - /* - ixns, err := c.ixnsFromArgs(args) - if err != nil { - c.UI.Error(fmt.Sprintf("Error: %s", err)) + svcs, err := c.svcsFromFiles(args) + if err != nil { + c.UI.Error(fmt.Sprintf("Error: %s", err)) + return 1 + } + + // Create and test the HTTP client + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + // Create all the services + for _, svc := range svcs { + if err := client.Agent().ServiceRegister(svc); err != nil { + c.UI.Error(fmt.Sprintf("Error registering service %q: %s", + svc.Name, err)) return 1 } - - // Create and test the HTTP client - /* - client, err := c.http.APIClient() - if err != nil { - c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) - return 1 - } - */ + } return 0 } +// svcsFromFiles loads service definitions from a set of configuration +// files and returns them. It will return an error if the configuration is +// invalid in any way. +func (c *cmd) svcsFromFiles(args []string) ([]*api.AgentServiceRegistration, error) { + // We set devMode to true so we can get the basic valid default + // configuration. devMode doesn't set any services by default so this + // is okay since we only look at services. + devMode := true + b, err := config.NewBuilder(config.Flags{ + ConfigFiles: args, + DevMode: &devMode, + }) + if err != nil { + return nil, err + } + + cfg, err := b.BuildAndValidate() + if err != nil { + return nil, err + } + + // The services are now in "structs.ServiceDefinition" form and we need + // them in "api.AgentServiceRegistration" form so do the conversion. + result := make([]*api.AgentServiceRegistration, 0, len(cfg.Services)) + for _, svc := range cfg.Services { + apiSvc, err := serviceToAgentService(svc) + if err != nil { + return nil, err + } + + result = append(result, apiSvc) + } + + return result, nil +} + func (c *cmd) Synopsis() string { return synopsis } diff --git a/command/services/register/register_test.go b/command/services/register/register_test.go new file mode 100644 index 000000000..f6681e119 --- /dev/null +++ b/command/services/register/register_test.go @@ -0,0 +1,73 @@ +package register + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/testutil" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestCommand_noTabs(t *testing.T) { + t.Parallel() + if strings.ContainsRune(New(nil).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestCommand_File(t *testing.T) { + t.Parallel() + + require := require.New(t) + a := agent.NewTestAgent(t.Name(), ``) + defer a.Shutdown() + client := a.Client() + + ui := cli.NewMockUi() + c := New(ui) + + contents := `{ "Service": { "Name": "web" } }` + f := testFile(t, "json") + defer os.Remove(f.Name()) + if _, err := f.WriteString(contents); err != nil { + t.Fatalf("err: %#v", err) + } + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + f.Name(), + } + + require.Equal(0, c.Run(args), ui.ErrorWriter.String()) + + svcs, err := client.Agent().Services() + require.NoError(err) + require.Len(svcs, 1) + + svc := svcs["web"] + require.NotNil(svc) +} + +func testFile(t *testing.T, suffix string) *os.File { + f := testutil.TempFile(t, "register-test-file") + if err := f.Close(); err != nil { + t.Fatalf("err: %s", err) + } + + newName := f.Name() + "." + suffix + if err := os.Rename(f.Name(), newName); err != nil { + os.Remove(f.Name()) + t.Fatalf("err: %s", err) + } + + f, err := os.Create(newName) + if err != nil { + os.Remove(newName) + t.Fatalf("err: %s", err) + } + + return f +}