diff --git a/command/commands.go b/command/commands.go index a934fd292..004c8d9d7 100644 --- a/command/commands.go +++ b/command/commands.go @@ -8,6 +8,7 @@ import ( "strings" "syscall" + "github.com/hashicorp/consul/command/join" "github.com/hashicorp/consul/command/validate" "github.com/hashicorp/consul/version" "github.com/mitchellh/cli" @@ -108,12 +109,7 @@ func init() { }, "join": func() (cli.Command, error) { - return &JoinCommand{ - BaseCommand: BaseCommand{ - Flags: FlagSetClientHTTP, - UI: ui, - }, - }, nil + return join.New(ui), nil }, "keygen": func() (cli.Command, error) { diff --git a/command/join.go b/command/join.go deleted file mode 100644 index ad03fca3c..000000000 --- a/command/join.go +++ /dev/null @@ -1,75 +0,0 @@ -package command - -import ( - "fmt" -) - -// JoinCommand is a Command implementation that tells a running Consul -// agent to join another. -type JoinCommand struct { - BaseCommand - - // flags - wan bool -} - -func (c *JoinCommand) initFlags() { - c.InitFlagSet() - c.FlagSet.BoolVar(&c.wan, "wan", false, - "Joins a server to another server in the WAN pool.") -} - -func (c *JoinCommand) Run(args []string) int { - c.initFlags() - if err := c.FlagSet.Parse(args); err != nil { - return 1 - } - - addrs := c.FlagSet.Args() - if len(addrs) == 0 { - c.UI.Error("At least one address to join must be specified.") - c.UI.Error("") - c.UI.Error(c.Help()) - return 1 - } - - client, err := c.HTTPClient() - if err != nil { - c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) - return 1 - } - - joins := 0 - for _, addr := range addrs { - err := client.Agent().Join(addr, c.wan) - if err != nil { - c.UI.Error(fmt.Sprintf("Error joining address '%s': %s", addr, err)) - } else { - joins++ - } - } - - if joins == 0 { - c.UI.Error("Failed to join any nodes.") - return 1 - } - - c.UI.Output(fmt.Sprintf( - "Successfully joined cluster by contacting %d nodes.", joins)) - return 0 -} - -func (c *JoinCommand) Help() string { - c.initFlags() - return c.HelpCommand(` -Usage: consul join [options] address ... - - Tells a running Consul agent (with "consul agent") to join the cluster - by specifying at least one existing member. - -`) -} - -func (c *JoinCommand) Synopsis() string { - return "Tell Consul agent to join cluster" -} diff --git a/command/join/join.go b/command/join/join.go new file mode 100644 index 000000000..f08163431 --- /dev/null +++ b/command/join/join.go @@ -0,0 +1,82 @@ +package join + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.initFlags() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + wan bool +} + +func (c *cmd) initFlags() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.wan, "wan", false, + "Joins a server to another server in the WAN pool.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + addrs := c.flags.Args() + if len(addrs) == 0 { + c.UI.Error("At least one address to join must be specified.") + c.UI.Error("") + c.UI.Error(c.Help()) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + joins := 0 + for _, addr := range addrs { + err := client.Agent().Join(addr, c.wan) + if err != nil { + c.UI.Error(fmt.Sprintf("Error joining address '%s': %s", addr, err)) + } else { + joins++ + } + } + + if joins == 0 { + c.UI.Error("Failed to join any nodes.") + return 1 + } + + c.UI.Output(fmt.Sprintf("Successfully joined cluster by contacting %d nodes.", joins)) + return 0 +} + +func (c *cmd) Synopsis() string { + return "Tell Consul agent to join cluster" +} + +func (c *cmd) Help() string { + s := `Usage: consul join [options] address ... + + Tells a running Consul agent (with "consul agent") to join the cluster + by specifying at least one existing member.` + + return flags.Usage(s, c.flags, c.http.ClientFlags(), nil) +} diff --git a/command/join_test.go b/command/join/join_test.go similarity index 72% rename from command/join_test.go rename to command/join/join_test.go index 2fa77b642..0abb59199 100644 --- a/command/join_test.go +++ b/command/join/join_test.go @@ -1,4 +1,4 @@ -package command +package join import ( "strings" @@ -8,35 +8,27 @@ import ( "github.com/mitchellh/cli" ) -func testJoinCommand(t *testing.T) (*cli.MockUi, *JoinCommand) { - ui := cli.NewMockUi() - return ui, &JoinCommand{ - BaseCommand: BaseCommand{ - UI: ui, - Flags: FlagSetClientHTTP, - }, +func TestJoinCommand_noTabs(t *testing.T) { + if strings.ContainsRune(New(nil).Help(), '\t') { + t.Fatal("usage has tabs") } } -func TestJoinCommand_implements(t *testing.T) { - t.Parallel() - var _ cli.Command = &JoinCommand{} -} - -func TestJoinCommandRun(t *testing.T) { +func TestJoinCommandJoinRun_lan(t *testing.T) { t.Parallel() a1 := agent.NewTestAgent(t.Name(), ``) a2 := agent.NewTestAgent(t.Name(), ``) defer a1.Shutdown() defer a2.Shutdown() - ui, c := testJoinCommand(t) + ui := cli.NewMockUi() + cmd := New(ui) args := []string{ "-http-addr=" + a1.HTTPAddr(), a2.Config.SerfBindAddrLAN.String(), } - code := c.Run(args) + code := cmd.Run(args) if code != 0 { t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) } @@ -53,14 +45,15 @@ func TestJoinCommandRun_wan(t *testing.T) { defer a1.Shutdown() defer a2.Shutdown() - ui, c := testJoinCommand(t) + ui := cli.NewMockUi() + cmd := New(ui) args := []string{ "-http-addr=" + a1.HTTPAddr(), "-wan", a2.Config.SerfBindAddrWAN.String(), } - code := c.Run(args) + code := cmd.Run(args) if code != 0 { t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) } @@ -72,10 +65,11 @@ func TestJoinCommandRun_wan(t *testing.T) { func TestJoinCommandRun_noAddrs(t *testing.T) { t.Parallel() - ui, c := testJoinCommand(t) + ui := cli.NewMockUi() + cmd := New(ui) args := []string{"-http-addr=foo"} - code := c.Run(args) + code := cmd.Run(args) if code != 1 { t.Fatalf("bad: %d", code) }