diff --git a/command/server.go b/command/server.go index fc2c96647..e962f7a45 100644 --- a/command/server.go +++ b/command/server.go @@ -44,6 +44,8 @@ type ServerCommand struct { meta.Meta + logger *log.Logger + ReloadFuncs map[string][]server.ReloadFunc } @@ -136,7 +138,7 @@ func (c *ServerCommand) Run(args []string) int { // Create a logger. We wrap it in a gated writer so that it doesn't // start logging too early. logGate := &gatedwriter.Writer{Writer: os.Stderr} - logger := log.New(&logutils.LevelFilter{ + c.logger = log.New(&logutils.LevelFilter{ Levels: []logutils.LogLevel{ "TRACE", "DEBUG", "INFO", "WARN", "ERR"}, MinLevel: logutils.LogLevel(strings.ToUpper(logLevel)), @@ -150,7 +152,7 @@ func (c *ServerCommand) Run(args []string) int { // Initialize the backend backend, err := physical.NewBackend( - config.Backend.Type, logger, config.Backend.Config) + config.Backend.Type, c.logger, config.Backend.Config) if err != nil { c.Ui.Error(fmt.Sprintf( "Error initializing backend of type %s: %s", @@ -179,7 +181,7 @@ func (c *ServerCommand) Run(args []string) int { AuditBackends: c.AuditBackends, CredentialBackends: c.CredentialBackends, LogicalBackends: c.LogicalBackends, - Logger: logger, + Logger: c.logger, DisableCache: config.DisableCache, DisableMlock: config.DisableMlock, MaxLeaseTTL: config.MaxLeaseTTL, @@ -190,7 +192,7 @@ func (c *ServerCommand) Run(args []string) int { var ok bool if config.HABackend != nil { habackend, err := physical.NewBackend( - config.HABackend.Type, logger, config.HABackend.Config) + config.HABackend.Type, c.logger, config.HABackend.Config) if err != nil { c.Ui.Error(fmt.Sprintf( "Error initializing backend of type %s: %s", @@ -322,7 +324,7 @@ func (c *ServerCommand) Run(args []string) int { // Initialize the listeners lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { - ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config) + ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, logGate) if err != nil { c.Ui.Error(fmt.Sprintf( "Error initializing listener of type %s: %s", @@ -351,6 +353,13 @@ func (c *ServerCommand) Run(args []string) int { } } + // Make sure we close all listeners from this point on + defer func() { + for _, ln := range lns { + ln.Close() + } + }() + infoKeys = append(infoKeys, "version") info["version"] = version.GetVersion().String() @@ -368,9 +377,6 @@ func (c *ServerCommand) Run(args []string) int { c.Ui.Output("") if verifyOnly { - for _, listener := range lns { - listener.Close() - } return 0 } @@ -410,10 +416,6 @@ func (c *ServerCommand) Run(args []string) int { } } - for _, listener := range lns { - listener.Close() - } - return 0 } diff --git a/command/server/config.go b/command/server/config.go index 1945040f4..1de86e17c 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -200,6 +200,7 @@ func ParseConfig(d string) (*Config, error) { } valid := []string{ + "atlas", "backend", "ha_backend", "listener", @@ -414,6 +415,8 @@ func parseHABackends(result *Config, list *ast.ObjectList) error { } func parseListeners(result *Config, list *ast.ObjectList) error { + var foundAtlas bool + listeners := make([]*Listener, 0, len(list.Items)) for _, item := range list.Items { key := "listener" @@ -423,10 +426,14 @@ func parseListeners(result *Config, list *ast.ObjectList) error { valid := []string{ "address", + "endpoint", + "infrastructure", + "node_id", "tls_disable", "tls_cert_file", "tls_key_file", "tls_min_version", + "token", } if err := checkHCLKeys(item.Val, valid); err != nil { return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key)) @@ -437,8 +444,27 @@ func parseListeners(result *Config, list *ast.ObjectList) error { return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key)) } + lnType := strings.ToLower(key) + + if lnType == "atlas" { + if foundAtlas { + return multierror.Prefix(fmt.Errorf("only one listener of type 'atlas' is permitted"), fmt.Sprintf("listeners.%s", key)) + } else { + foundAtlas = true + if m["token"] == "" { + return multierror.Prefix(fmt.Errorf("'token' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key)) + } + if m["infrastructure"] == "" { + return multierror.Prefix(fmt.Errorf("'infrastructure' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key)) + } + if m["node_id"] == "" { + return multierror.Prefix(fmt.Errorf("'node_id' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key)) + } + } + } + listeners = append(listeners, &Listener{ - Type: strings.ToLower(key), + Type: lnType, Config: m, }) } diff --git a/command/server/config_test.go b/command/server/config_test.go index 505e8fa08..8841990c0 100644 --- a/command/server/config_test.go +++ b/command/server/config_test.go @@ -15,6 +15,15 @@ func TestLoadConfigFile(t *testing.T) { expected := &Config{ Listeners: []*Listener{ + &Listener{ + Type: "atlas", + Config: map[string]string{ + "token": "foobar", + "infrastructure": "foo/bar", + "endpoint": "https://foo.bar:1111", + "node_id": "foo_node", + }, + }, &Listener{ Type: "tcp", Config: map[string]string{ @@ -72,6 +81,15 @@ func TestLoadConfigFile_json(t *testing.T) { "address": "127.0.0.1:443", }, }, + &Listener{ + Type: "atlas", + Config: map[string]string{ + "token": "foobar", + "infrastructure": "foo/bar", + "endpoint": "https://foo.bar:1111", + "node_id": "foo_node", + }, + }, }, Backend: &Backend{ diff --git a/command/server/listener.go b/command/server/listener.go index bbe9d18bb..051ab2e8b 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -6,17 +6,19 @@ import ( _ "crypto/sha512" "crypto/tls" "fmt" + "io" "net" "strconv" "sync" ) // ListenerFactory is the factory function to create a listener. -type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFunc, error) +type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, ReloadFunc, error) // BuiltinListeners is the list of built-in listener types. var BuiltinListeners = map[string]ListenerFactory{ - "tcp": tcpListenerFactory, + "tcp": tcpListenerFactory, + "atlas": atlasListenerFactory, } // tlsLookup maps the tls_min_version configuration to the internal value @@ -28,13 +30,13 @@ var tlsLookup = map[string]uint16{ // NewListener creates a new listener of the given type with the given // configuration. The type is looked up in the BuiltinListeners map. -func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { +func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) { f, ok := BuiltinListeners[t] if !ok { return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t) } - return f(config) + return f(config, logger) } func listenerWrapTLS( diff --git a/command/server/listener_atlas.go b/command/server/listener_atlas.go new file mode 100644 index 000000000..5d36c6c6c --- /dev/null +++ b/command/server/listener_atlas.go @@ -0,0 +1,60 @@ +package server + +import ( + "io" + "net" + + "github.com/hashicorp/scada-client/scada" + "github.com/hashicorp/vault/version" +) + +type SCADAListener struct { + ln net.Listener + scadaProvider *scada.Provider +} + +func (s *SCADAListener) Accept() (net.Conn, error) { + return s.ln.Accept() +} + +func (s *SCADAListener) Close() error { + s.scadaProvider.Shutdown() + return s.ln.Close() +} + +func (s *SCADAListener) Addr() net.Addr { + return s.ln.Addr() +} + +func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) { + scadaConfig := &scada.Config{ + Service: "vault", + Version: version.GetVersion().String(), + ResourceType: "vault-cluster", + Meta: map[string]string{ + "node_id": config["node_id"], + }, + Atlas: scada.AtlasConfig{ + Endpoint: config["endpoint"], + Infrastructure: config["infrastructure"], + Token: config["token"], + }, + } + + provider, list, err := scada.NewHTTPProvider(scadaConfig, logger) + if err != nil { + return nil, nil, nil, err + } + + ln := &SCADAListener{ + ln: list, + scadaProvider: provider, + } + + props := map[string]string{ + "addr": "Atlas/SCADA", + "infrastructure": scadaConfig.Atlas.Infrastructure, + } + + return listenerWrapTLS(ln, props, config) +} diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index d4ba3aaff..c35126338 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -1,11 +1,12 @@ package server import ( + "io" "net" "time" ) -func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { +func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, ReloadFunc, error) { addr, ok := config["address"] if !ok { addr = "127.0.0.1:8200" diff --git a/command/server/listener_tcp_test.go b/command/server/listener_tcp_test.go index 74b65ec69..7da203377 100644 --- a/command/server/listener_tcp_test.go +++ b/command/server/listener_tcp_test.go @@ -16,7 +16,7 @@ func TestTCPListener(t *testing.T) { ln, _, _, err := tcpListenerFactory(map[string]string{ "address": "127.0.0.1:0", "tls_disable": "1", - }) + }, nil) if err != nil { t.Fatalf("err: %s", err) } @@ -52,7 +52,7 @@ func TestTCPListener_tls(t *testing.T) { "address": "127.0.0.1:0", "tls_cert_file": wd + "reload_foo.pem", "tls_key_file": wd + "reload_foo.key", - }) + }, nil) if err != nil { t.Fatalf("err: %s", err) } diff --git a/command/server/test-fixtures/config.hcl b/command/server/test-fixtures/config.hcl index a591b838e..122710bf4 100644 --- a/command/server/test-fixtures/config.hcl +++ b/command/server/test-fixtures/config.hcl @@ -3,6 +3,13 @@ disable_mlock = true statsd_addr = "bar" statsite_addr = "foo" +listener "atlas" { + token = "foobar" + infrastructure = "foo/bar" + endpoint = "https://foo.bar:1111" + node_id = "foo_node" +} + listener "tcp" { address = "127.0.0.1:443" } diff --git a/command/server/test-fixtures/config.hcl.json b/command/server/test-fixtures/config.hcl.json index 02ab8eabe..094dc154a 100644 --- a/command/server/test-fixtures/config.hcl.json +++ b/command/server/test-fixtures/config.hcl.json @@ -1,17 +1,24 @@ { - "listener":{ - "tcp":{ - "address":"127.0.0.1:443" - } - }, - "backend":{ - "consul":{ - "foo":"bar" - } - }, - "telemetry":{ - "statsite_address":"baz" - }, - "max_lease_ttl":"10h", - "default_lease_ttl":"10h" + "listener": [{ + "tcp": { + "address": "127.0.0.1:443" + } + }, { + "atlas": { + "token": "foobar", + "infrastructure": "foo/bar", + "endpoint": "https://foo.bar:1111", + "node_id": "foo_node" + } + }], + "backend": { + "consul": { + "foo": "bar" + } + }, + "telemetry": { + "statsite_address": "baz" + }, + "max_lease_ttl": "10h", + "default_lease_ttl": "10h" }