diff --git a/command/agent/command_test.go b/command/agent/command_test.go index 983944698..9f9317e7b 100644 --- a/command/agent/command_test.go +++ b/command/agent/command_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "log" "os" + "path/filepath" "strings" "testing" @@ -299,3 +300,39 @@ func TestSetupScadaConn(t *testing.T) { t.Fatalf("should be closed") } } + +func TestProtectDataDir(t *testing.T) { + dir, err := ioutil.TempDir("", "consul") + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir) + + if err := os.MkdirAll(filepath.Join(dir, "mdb"), 0700); err != nil { + t.Fatalf("err: %v", err) + } + + cfgFile, err := ioutil.TempFile("", "consul") + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.Remove(cfgFile.Name()) + + content := fmt.Sprintf(`{"server": true, "data_dir": "%s"}`, dir) + _, err = cfgFile.Write([]byte(content)) + if err != nil { + t.Fatalf("err: %v", err) + } + + ui := new(cli.MockUi) + cmd := &Command{ + Ui: ui, + args: []string{"-config-file=" + cfgFile.Name()}, + } + if conf := cmd.readConfig(); conf != nil { + t.Fatalf("should fail") + } + if out := ui.ErrorWriter.String(); !strings.Contains(out, dir) { + t.Fatalf("expected mdb dir error, got: %s", out) + } +}