diff --git a/api/api.go b/api/api.go index 9506b4e1f..dae906c84 100644 --- a/api/api.go +++ b/api/api.go @@ -5,9 +5,12 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" + "os" "strconv" + "strings" "time" ) @@ -111,11 +114,17 @@ type Config struct { // DefaultConfig returns a default configuration for the client func DefaultConfig() *Config { - return &Config{ + config := &Config{ Address: "127.0.0.1:8500", Scheme: "http", HttpClient: http.DefaultClient, } + + if len(os.Getenv("CONSUL_HTTP_ADDR")) > 0 { + config.Address = os.Getenv("CONSUL_HTTP_ADDR") + } + + return config } // Client provides a client to the Consul API @@ -128,7 +137,11 @@ func NewClient(config *Config) (*Client, error) { // bootstrap the config defConfig := DefaultConfig() - if len(config.Address) == 0 { + switch { + case len(config.Address) != 0: + case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: + config.Address = os.Getenv("CONSUL_HTTP_ADDR") + default: config.Address = defConfig.Address } @@ -140,6 +153,16 @@ func NewClient(config *Config) (*Client, error) { config.HttpClient = defConfig.HttpClient } + if strings.HasPrefix(config.Address, "unix://") { + shortStr := strings.TrimPrefix(config.Address, "unix://") + t := &http.Transport{} + t.Dial = func(_, _ string) (net.Conn, error) { + return net.Dial("unix", shortStr) + } + config.HttpClient.Transport = t + config.Address = shortStr + } + client := &Client{ config: *config, } @@ -206,9 +229,6 @@ func (r *request) toHTTP() (*http.Request, error) { // Encode the query parameters r.url.RawQuery = r.params.Encode() - // Get the url sring - urlRaw := r.url.String() - // Check if we should encode the body if r.body == nil && r.obj != nil { if b, err := encodeBody(r.obj); err != nil { @@ -219,14 +239,21 @@ func (r *request) toHTTP() (*http.Request, error) { } // Create the HTTP request - req, err := http.NewRequest(r.method, urlRaw, r.body) + req, err := http.NewRequest(r.method, r.url.RequestURI(), r.body) + if err != nil { + return nil, err + } + + req.URL.Host = r.url.Host + req.URL.Scheme = r.url.Scheme + req.Host = r.url.Host // Setup auth if err == nil && r.config.HttpAuth != nil { req.SetBasicAuth(r.config.HttpAuth.Username, r.config.HttpAuth.Password) } - return req, err + return req, nil } // newRequest is used to create a new request diff --git a/command/agent/command.go b/command/agent/command.go index 82e111caa..c7da28f4d 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -295,13 +295,26 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log return err } - rpcListener, err := net.Listen("tcp", rpcAddr.String()) + if _, ok := rpcAddr.(*net.UnixAddr); ok { + // Remove the socket if it exists, or we'll get a bind error + _ = os.Remove(rpcAddr.String()) + } + + rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String()) if err != nil { agent.Shutdown() c.Ui.Error(fmt.Sprintf("Error starting RPC listener: %s", err)) return err } + if _, ok := rpcAddr.(*net.UnixAddr); ok { + if err := adjustUnixSocketPermissions(config.Addresses.RPC); err != nil { + agent.Shutdown() + c.Ui.Error(fmt.Sprintf("Error adjusting Unix socket permissions: %s", err)) + return err + } + } + // Start the IPC layer c.Ui.Output("Starting Consul agent RPC...") c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter) @@ -319,6 +332,7 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log if config.Ports.DNS > 0 { dnsAddr, err := config.ClientListener(config.Addresses.DNS, config.Ports.DNS) if err != nil { + agent.Shutdown() c.Ui.Error(fmt.Sprintf("Invalid DNS bind address: %s", err)) return err } diff --git a/command/agent/config.go b/command/agent/config.go index 3e1413492..2e077115f 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -7,8 +7,10 @@ import ( "io" "net" "os" + "os/user" "path/filepath" "sort" + "strconv" "strings" "time" @@ -345,6 +347,82 @@ type Config struct { WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"` } +// UnixSocket contains the parameters for a Unix socket interface +type UnixSocket struct { + // Path to the socket on-disk + Path string + + // uid of the owner of the socket + Uid int + + // gid of the group of the socket + Gid int + + // Permissions for the socket file + Permissions os.FileMode +} + +func populateUnixSocket(addr string) (*UnixSocket, error) { + if !strings.HasPrefix(addr, "unix://") { + return nil, fmt.Errorf("Failed to parse Unix address, format is [path];[user];[group];[mode]: %v", addr) + } + + splitAddr := strings.Split(strings.TrimPrefix(addr, "unix://"), ";") + if len(splitAddr) != 4 { + return nil, fmt.Errorf("Failed to parse Unix address, format is [path];[user];[group];[mode]: %v", addr) + } + + ret := &UnixSocket{Path: splitAddr[0]} + + if userVal, err := user.Lookup(splitAddr[1]); err != nil { + return nil, fmt.Errorf("Invalid user given for Unix socket ownership: %v", splitAddr[1]) + } else { + if uid64, err := strconv.ParseInt(userVal.Uid, 10, 32); err != nil { + return nil, fmt.Errorf("Failed to parse given user ID of %v into integer", userVal.Uid) + } else { + ret.Uid = int(uid64) + } + } + + // Go doesn't currently have a way to look up gid from group name, + // so require a numeric gid; see + // https://codereview.appspot.com/101310044 + if gid64, err := strconv.ParseInt(splitAddr[2], 10, 32); err != nil { + return nil, fmt.Errorf("Socket group must be given as numeric gid. Failed to parse given group ID of %v into integer", splitAddr[2]) + } else { + ret.Gid = int(gid64) + } + + if mode, err := strconv.ParseUint(splitAddr[3], 8, 32); err != nil { + return nil, fmt.Errorf("Failed to parse given mode of %v into integer", splitAddr[3]) + } else { + if mode > 0777 { + return nil, fmt.Errorf("Given mode is invalid; must be an octal number between 0 and 777") + } else { + ret.Permissions = os.FileMode(mode) + } + } + + return ret, nil +} + +func adjustUnixSocketPermissions(addr string) error { + sock, err := populateUnixSocket(addr) + if err != nil { + return err + } + + if err = os.Chown(sock.Path, sock.Uid, sock.Gid); err != nil { + return fmt.Errorf("Error attempting to change socket permissions to userid %v and groupid %v: %v", sock.Uid, sock.Gid, err) + } + + if err = os.Chmod(sock.Path, sock.Permissions); err != nil { + return fmt.Errorf("Error attempting to change socket permissions to mode %v: %v", sock.Permissions, err) + } + + return nil +} + type dirEnts []os.FileInfo // DefaultConfig is used to return a sane default configuration @@ -389,18 +467,30 @@ func (c *Config) EncryptBytes() ([]byte, error) { // ClientListener is used to format a listener for a // port on a ClientAddr -func (c *Config) ClientListener(override string, port int) (*net.TCPAddr, error) { +func (c *Config) ClientListener(override string, port int) (net.Addr, error) { var addr string if override != "" { addr = override } else { addr = c.ClientAddr } - ip := net.ParseIP(addr) - if ip == nil { - return nil, fmt.Errorf("Failed to parse IP: %v", addr) + + switch { + case strings.HasPrefix(addr, "unix://"): + sock, err := populateUnixSocket(addr) + if err != nil { + return nil, err + } + + return &net.UnixAddr{Name: sock.Path, Net: "unix"}, nil + + default: + ip := net.ParseIP(addr) + if ip == nil { + return nil, fmt.Errorf("Failed to parse IP: %v", addr) + } + return &net.TCPAddr{IP: ip, Port: port}, nil } - return &net.TCPAddr{IP: ip, Port: port}, nil } // ClientListenerAddr is used to format an address for a @@ -410,8 +500,11 @@ func (c *Config) ClientListenerAddr(override string, port int) (string, error) { if err != nil { return "", err } - if addr.IP.IsUnspecified() { - addr.IP = net.ParseIP("127.0.0.1") + + if ipAddr, ok := addr.(*net.TCPAddr); ok { + if ipAddr.IP.IsUnspecified() { + ipAddr.IP = net.ParseIP("127.0.0.1") + } } return addr.String(), nil } diff --git a/command/agent/http.go b/command/agent/http.go index 3fe2feec5..d480de816 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/http/pprof" + "os" "strconv" "strings" "time" @@ -34,7 +35,7 @@ type HTTPServer struct { func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPServer, error) { var tlsConfig *tls.Config var list net.Listener - var httpAddr *net.TCPAddr + var httpAddr net.Addr var err error var servers []*HTTPServer @@ -58,12 +59,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, err } - ln, err := net.Listen("tcp", httpAddr.String()) - if err != nil { - return nil, err + if _, ok := httpAddr.(*net.UnixAddr); ok { + // Remove the socket if it exists, or we'll get a bind error + _ = os.Remove(httpAddr.String()) } - list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) + ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) + if err != nil { + return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) + } + + switch httpAddr.(type) { + case *net.UnixAddr: + if err := adjustUnixSocketPermissions(config.Addresses.HTTPS); err != nil { + return nil, err + } + list = tls.NewListener(ln, tlsConfig) + + case *net.TCPAddr: + list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) + + default: + return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) + } // Create the mux mux := http.NewServeMux() @@ -90,13 +108,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err) } - // Create non-TLS listener - ln, err := net.Listen("tcp", httpAddr.String()) + if _, ok := httpAddr.(*net.UnixAddr); ok { + // Remove the socket if it exists, or we'll get a bind error + _ = os.Remove(httpAddr.String()) + } + + ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) if err != nil { return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) } - list = tcpKeepAliveListener{ln.(*net.TCPListener)} + switch httpAddr.(type) { + case *net.UnixAddr: + if err := adjustUnixSocketPermissions(config.Addresses.HTTP); err != nil { + return nil, err + } + list = ln + + case *net.TCPAddr: + list = tcpKeepAliveListener{ln.(*net.TCPListener)} + + default: + return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) + } // Create the mux mux := http.NewServeMux() diff --git a/command/agent/rpc_client.go b/command/agent/rpc_client.go index 7ba1907b2..c769ef1d4 100644 --- a/command/agent/rpc_client.go +++ b/command/agent/rpc_client.go @@ -7,6 +7,8 @@ import ( "github.com/hashicorp/logutils" "log" "net" + "os" + "strings" "sync" "sync/atomic" ) @@ -34,7 +36,7 @@ type seqHandler interface { type RPCClient struct { seq uint64 - conn *net.TCPConn + conn net.Conn reader *bufio.Reader writer *bufio.Writer dec *codec.Decoder @@ -79,8 +81,18 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error { // NewRPCClient is used to create a new RPC client given the address. // This will properly dial, handshake, and start listening func NewRPCClient(addr string) (*RPCClient, error) { + sanedAddr := os.Getenv("CONSUL_RPC_ADDR") + if len(sanedAddr) == 0 { + sanedAddr = addr + } + mode := "tcp" + if strings.HasPrefix(sanedAddr, "unix://") { + sanedAddr = strings.TrimPrefix(sanedAddr, "unix://") + mode = "unix" + } + // Try to dial to agent - conn, err := net.Dial("tcp", addr) + conn, err := net.Dial(mode, sanedAddr) if err != nil { return nil, err } @@ -88,7 +100,7 @@ func NewRPCClient(addr string) (*RPCClient, error) { // Create the client client := &RPCClient{ seq: 0, - conn: conn.(*net.TCPConn), + conn: conn, reader: bufio.NewReader(conn), writer: bufio.NewWriter(conn), dispatch: make(map[uint64]seqHandler), diff --git a/command/rpc.go b/command/rpc.go index f70fb4f23..f0c9e5b1f 100644 --- a/command/rpc.go +++ b/command/rpc.go @@ -8,8 +8,8 @@ import ( "github.com/hashicorp/consul/command/agent" ) -// RPCAddrEnvName defines the environment variable name, which can set -// a default RPC address in case there is no -rpc-addr specified. +// RPCAddrEnvName defines an environment variable name which sets +// an RPC address if there is no -rpc-addr specified. const RPCAddrEnvName = "CONSUL_RPC_ADDR" // RPCAddrFlag returns a pointer to a string that will be populated @@ -43,7 +43,12 @@ func HTTPClient(addr string) (*consulapi.Client, error) { // HTTPClientDC returns a new Consul HTTP client with the given address and datacenter func HTTPClientDC(addr, dc string) (*consulapi.Client, error) { conf := consulapi.DefaultConfig() - conf.Address = addr + switch { + case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: + conf.Address = os.Getenv("CONSUL_HTTP_ADDR") + default: + conf.Address = addr + } conf.Datacenter = dc return consulapi.NewClient(conf) }