package zk import ( "fmt" "io" "io/ioutil" "math/rand" "os" "path/filepath" "strings" "time" ) type TestServer struct { Port int Path string Srv *Server } type TestCluster struct { Path string Servers []TestServer } func StartTestCluster(size int, stdout, stderr io.Writer) (*TestCluster, error) { tmpPath, err := ioutil.TempDir("", "gozk") if err != nil { return nil, err } success := false startPort := int(rand.Int31n(6000) + 10000) cluster := &TestCluster{Path: tmpPath} defer func() { if !success { cluster.Stop() } }() for serverN := 0; serverN < size; serverN++ { srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN)) if err := os.Mkdir(srvPath, 0700); err != nil { return nil, err } port := startPort + serverN*3 cfg := ServerConfig{ ClientPort: port, DataDir: srvPath, } for i := 0; i < size; i++ { cfg.Servers = append(cfg.Servers, ServerConfigServer{ ID: i + 1, Host: "127.0.0.1", PeerPort: startPort + i*3 + 1, LeaderElectionPort: startPort + i*3 + 2, }) } cfgPath := filepath.Join(srvPath, "zoo.cfg") fi, err := os.Create(cfgPath) if err != nil { return nil, err } err = cfg.Marshall(fi) fi.Close() if err != nil { return nil, err } fi, err = os.Create(filepath.Join(srvPath, "myid")) if err != nil { return nil, err } _, err = fmt.Fprintf(fi, "%d\n", serverN+1) fi.Close() if err != nil { return nil, err } srv := &Server{ ConfigPath: cfgPath, Stdout: stdout, Stderr: stderr, } if err := srv.Start(); err != nil { return nil, err } cluster.Servers = append(cluster.Servers, TestServer{ Path: srvPath, Port: cfg.ClientPort, Srv: srv, }) } success = true time.Sleep(3 * time.Second) // Give the server time to become active. Should probably actually attempt to connect to verify. return cluster, nil } func (ts *TestCluster) Connect(idx int) (*Conn, error) { zk, _, err := Connect([]string{fmt.Sprintf("127.0.0.1:%d", ts.Servers[idx].Port)}, time.Second*15) return zk, err } func (ts *TestCluster) ConnectAll() (*Conn, <-chan Event, error) { return ts.ConnectAllTimeout(time.Second * 15) } func (ts *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) { hosts := make([]string, len(ts.Servers)) for i, srv := range ts.Servers { hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port) } zk, ch, err := Connect(hosts, sessionTimeout) return zk, ch, err } func (ts *TestCluster) Stop() error { for _, srv := range ts.Servers { srv.Srv.Stop() } defer os.RemoveAll(ts.Path) return nil } func (tc *TestCluster) StartServer(server string) { for _, s := range tc.Servers { if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { s.Srv.Start() return } } panic(fmt.Sprintf("Unknown server: %s", server)) } func (tc *TestCluster) StopServer(server string) { for _, s := range tc.Servers { if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { s.Srv.Stop() return } } panic(fmt.Sprintf("Unknown server: %s", server)) }