From 29d9b831d3a59dc9154e6b4c7df3983a0477bd14 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 14:40:11 -0700 Subject: [PATCH] Update the api for serving plugins and provide a utility to pass TLS data for commuinicating with the vault process --- builtin/logical/database/backend_test.go | 68 ++++-- .../logical/database/dbplugin/plugin_test.go | 55 +++-- builtin/logical/database/dbplugin/server.go | 13 +- helper/pluginutil/runner.go | 42 ++++ helper/pluginutil/tls.go | 199 +++++++++--------- .../cassandra-database-plugin/main.go | 7 +- plugins/database/cassandra/cassandra.go | 6 +- .../mssql/mssql-database-plugin/main.go | 7 +- plugins/database/mssql/mssql.go | 6 +- .../mysql/mysql-database-plugin/main.go | 7 +- plugins/database/mysql/mysql.go | 6 +- .../postgresql-database-plugin/main.go | 7 +- plugins/database/postgresql/postgresql.go | 6 +- plugins/serve.go | 31 +++ vault/testing.go | 3 + 15 files changed, 310 insertions(+), 153 deletions(-) create mode 100644 plugins/serve.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 08317cbdc..70ec22ee2 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -4,12 +4,13 @@ import ( "database/sql" "fmt" "log" - "net" + stdhttp "net/http" "os" "reflect" "sync" "testing" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" @@ -77,13 +78,30 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac return } -func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) { - core, _, token, ln := vault.TestCoreUnsealedWithListener(t) - http.TestServerWithListener(t, ln, "", core) - sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", "TestBackend_PluginMain") +func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "database": Factory, + }, + } - return core, ln, sys, token + handler1 := stdhttp.NewServeMux() + handler2 := stdhttp.NewServeMux() + handler3 := stdhttp.NewServeMux() + + // Chicken-and-egg: Handler needs a core. So we create handlers first, then + // add routes chained to a Handler-created handler. + cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) + handler1.Handle("/", http.Handler(cores[0].Core)) + handler2.Handle("/", http.Handler(cores[1].Core)) + handler3.Handle("/", http.Handler(cores[2].Core)) + + core := cores[0] + + sys := vault.TestDynamicSystemView(core.Core) + vault.TestAddTestPlugin(t, core.Core, "postgresql-database-plugin", "TestBackend_PluginMain") + + return cores, sys } func TestBackend_PluginMain(t *testing.T) { @@ -91,14 +109,20 @@ func TestBackend_PluginMain(t *testing.T) { return } - postgresql.Run() + err := postgresql.Run(&api.TLSConfig{Insecure: true}) + if err != nil { + t.Fatal(err) + } + t.Fatal("We shouldn't get here") } func TestBackend_config_connection(t *testing.T) { var resp *logical.Response var err error - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -147,8 +171,10 @@ func TestBackend_config_connection(t *testing.T) { } func TestBackend_basic(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -238,8 +264,10 @@ func TestBackend_basic(t *testing.T) { } func TestBackend_connectionCrud(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -383,8 +411,10 @@ func TestBackend_connectionCrud(t *testing.T) { } func TestBackend_roleCrud(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -493,8 +523,10 @@ func TestBackend_roleCrud(t *testing.T) { } } func TestBackend_allowedRoles(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 1587ba24a..c38d85ed3 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -2,15 +2,17 @@ package dbplugin_test import ( "errors" - "net" + stdhttp "net/http" "os" "testing" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/vault" log "github.com/mgutz/logxi/v1" ) @@ -72,13 +74,26 @@ func (m *mockPlugin) Close() error { return nil } -func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { - core, _, _, ln := vault.TestCoreUnsealedWithListener(t) - http.TestServerWithListener(t, ln, "", core) - sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "test-plugin", "TestPlugin_Main") +func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { + coreConfig := &vault.CoreConfig{} - return core, ln, sys + handler1 := stdhttp.NewServeMux() + handler2 := stdhttp.NewServeMux() + handler3 := stdhttp.NewServeMux() + + // Chicken-and-egg: Handler needs a core. So we create handlers first, then + // add routes chained to a Handler-created handler. + cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) + handler1.Handle("/", http.Handler(cores[0].Core)) + handler2.Handle("/", http.Handler(cores[1].Core)) + handler3.Handle("/", http.Handler(cores[2].Core)) + + core := cores[0] + + sys := vault.TestDynamicSystemView(core.Core) + vault.TestAddTestPlugin(t, core.Core, "test-plugin", "TestPlugin_Main") + + return cores, sys } // This is not an actual test case, it's a helper function that will be executed @@ -92,12 +107,14 @@ func TestPlugin_Main(t *testing.T) { users: make(map[string][]string), } - dbplugin.NewPluginServer(plugin) + plugins.Serve(plugin, &api.TLSConfig{Insecure: true}) } func TestPlugin_Initialize(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -120,8 +137,10 @@ func TestPlugin_Initialize(t *testing.T) { } func TestPlugin_CreateUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -155,8 +174,10 @@ func TestPlugin_CreateUser(t *testing.T) { } func TestPlugin_RenewUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -184,8 +205,10 @@ func TestPlugin_RenewUser(t *testing.T) { } func TestPlugin_RevokeUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 32c377e13..9546d092c 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,16 +1,15 @@ package dbplugin import ( - "fmt" + "crypto/tls" "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vault/helper/pluginutil" ) // Serve is called from within a plugin and wraps the provided // Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func Serve(db Database) { +func Serve(db Database, tlsProvider func() (*tls.Config, error)) { dbPlugin := &DatabasePlugin{ impl: db, } @@ -20,16 +19,10 @@ func Serve(db Database) { "database": dbPlugin, } - err := pluginutil.OptionallyEnableMlock() - if err != nil { - fmt.Println(err) - return - } - plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, - TLSProvider: pluginutil.VaultPluginTLSProvider, + TLSProvider: tlsProvider, }) } diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 0617f7624..91439a3b8 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -2,11 +2,13 @@ package pluginutil import ( "crypto/sha256" + "flag" "fmt" "os/exec" "time" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/wrapping" ) @@ -87,3 +89,43 @@ func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugi return client, nil } + +type APIClientMeta struct { + // These are set by the command line flags. + flagCACert string + flagCAPath string + flagClientCert string + flagClientKey string + flagInsecure bool +} + +func (f *APIClientMeta) FlagSet() *flag.FlagSet { + fs := flag.NewFlagSet("tls settings", flag.ContinueOnError) + + fs.StringVar(&f.flagCACert, "ca-cert", "", "") + fs.StringVar(&f.flagCAPath, "ca-path", "", "") + fs.StringVar(&f.flagClientCert, "client-cert", "", "") + fs.StringVar(&f.flagClientKey, "client-key", "", "") + fs.BoolVar(&f.flagInsecure, "insecure", false, "") + fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "") + + return fs +} + +func (f *APIClientMeta) GetTLSConfig() *api.TLSConfig { + // If we need custom TLS configuration, then set it + if f.flagCACert != "" || f.flagCAPath != "" || f.flagClientCert != "" || f.flagClientKey != "" || f.flagInsecure { + t := &api.TLSConfig{ + CACert: f.flagCACert, + CAPath: f.flagCAPath, + ClientCert: f.flagClientCert, + ClientKey: f.flagClientKey, + TLSServerName: "", + Insecure: f.flagInsecure, + } + + return t + } + + return nil +} diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 05804a33b..b355079d6 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -116,109 +116,114 @@ func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) ( // VaultPluginTLSProvider is run inside a plugin and retrives the response // wrapped TLS certificate from vault. It returns a configured TLS Config. -func VaultPluginTLSProvider() (*tls.Config, error) { - unwrapToken := os.Getenv(PluginUnwrapTokenEnv) +func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) { + return func() (*tls.Config, error) { + unwrapToken := os.Getenv(PluginUnwrapTokenEnv) - // Ensure unwrap token is a JWT - if strings.Count(unwrapToken, ".") != 2 { - return nil, errors.New("Could not parse unwraptoken") - } + // Ensure unwrap token is a JWT + if strings.Count(unwrapToken, ".") != 2 { + return nil, errors.New("Could not parse unwraptoken") + } - // Parse the JWT and retrieve the vault address - wt, err := jws.ParseJWT([]byte(unwrapToken)) - if err != nil { - return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) - } - if wt == nil { - return nil, errors.New("nil decoded token") - } + // Parse the JWT and retrieve the vault address + wt, err := jws.ParseJWT([]byte(unwrapToken)) + if err != nil { + return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) + } + if wt == nil { + return nil, errors.New("nil decoded token") + } - addrRaw := wt.Claims().Get("addr") - if addrRaw == nil { - return nil, errors.New("decoded token does not contain primary cluster address") - } - vaultAddr, ok := addrRaw.(string) - if !ok { - return nil, errors.New("decoded token's address not valid") - } - if vaultAddr == "" { - return nil, errors.New(`no address for the vault found`) - } + addrRaw := wt.Claims().Get("addr") + if addrRaw == nil { + return nil, errors.New("decoded token does not contain primary cluster address") + } + vaultAddr, ok := addrRaw.(string) + if !ok { + return nil, errors.New("decoded token's address not valid") + } + if vaultAddr == "" { + return nil, errors.New(`no address for the vault found`) + } - // Sanity check the value - if _, err := url.Parse(vaultAddr); err != nil { - return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) - } + // Sanity check the value + if _, err := url.Parse(vaultAddr); err != nil { + return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) + } - // Unwrap the token - clientConf := api.DefaultConfig() - clientConf.Address = vaultAddr - client, err := api.NewClient(clientConf) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } + // Unwrap the token + clientConf := api.DefaultConfig() + clientConf.Address = vaultAddr + if apiTLSConfig != nil { + clientConf.ConfigureTLS(apiTLSConfig) + } + client, err := api.NewClient(clientConf) + if err != nil { + return nil, errwrap.Wrapf("error during api client creation: {{err}}", err) + } - secret, err := client.Logical().Unwrap(unwrapToken) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - if secret == nil { - return nil, errors.New("error during token unwrap request secret is nil") - } + secret, err := client.Logical().Unwrap(unwrapToken) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + if secret == nil { + return nil, errors.New("error during token unwrap request secret is nil") + } - // Retrieve and parse the server's certificate - serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") + // Retrieve and parse the server's certificate + serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCert, err := x509.ParseCertificate(serverCertBytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Retrieve and parse the server's private key + serverKeyB64, ok := secret.Data["ServerKey"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Add CA cert to the cert pool + caCertPool := x509.NewCertPool() + caCertPool.AddCert(serverCert) + + // Build a certificate object out of the server's cert and private key. + cert := tls.Certificate{ + Certificate: [][]byte{serverCertBytes}, + PrivateKey: serverKey, + Leaf: serverCert, + } + + // Setup TLS config + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + RootCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + // TLS 1.2 minimum + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil } - - serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverCert, err := x509.ParseCertificate(serverCertBytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - // Retrieve and parse the server's private key - serverKeyB64, ok := secret.Data["ServerKey"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - // Add CA cert to the cert pool - caCertPool := x509.NewCertPool() - caCertPool.AddCert(serverCert) - - // Build a certificate object out of the server's cert and private key. - cert := tls.Certificate{ - Certificate: [][]byte{serverCertBytes}, - PrivateKey: serverKey, - Leaf: serverCert, - } - - // Setup TLS config - tlsConfig := &tls.Config{ - ClientCAs: caCertPool, - RootCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - // TLS 1.2 minimum - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{cert}, - } - tlsConfig.BuildNameToCertificate() - - return tlsConfig, nil } diff --git a/plugins/database/cassandra/cassandra-database-plugin/main.go b/plugins/database/cassandra/cassandra-database-plugin/main.go index 79f0e0dbe..bb3f44142 100644 --- a/plugins/database/cassandra/cassandra-database-plugin/main.go +++ b/plugins/database/cassandra/cassandra-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/cassandra" ) func main() { - err := cassandra.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := cassandra.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index bf1cbab92..60e445ff6 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -6,8 +6,10 @@ import ( "github.com/gocql/gocql" multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -41,13 +43,13 @@ func New() (interface{}, error) { } // Run instantiates a Cassandra object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.NewPluginServer(dbType.(*Cassandra)) + plugins.Serve(dbType.(*Cassandra), apiTLSConfig) return nil } diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go index ead1cf842..d52fd13db 100644 --- a/plugins/database/mssql/mssql-database-plugin/main.go +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/mssql" ) func main() { - err := mssql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := mssql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index d82efce6f..9b22aa87c 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -6,8 +6,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -39,13 +41,13 @@ func New() (interface{}, error) { } // Run instantiates a MSSQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*MSSQL)) + plugins.Serve(dbType.(*MSSQL), apiTLSConfig) return nil } diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go index c0ec75c9c..a9389f504 100644 --- a/plugins/database/mysql/mysql-database-plugin/main.go +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/mysql" ) func main() { - err := mysql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := mysql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 7eb680759..7a44d7341 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -5,8 +5,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -42,13 +44,13 @@ func New() (interface{}, error) { } // Run instantiates a MySQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*MySQL)) + plugins.Serve(dbType.(*MySQL), apiTLSConfig) return nil } diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go index 9b9b813c4..e6acb0584 100644 --- a/plugins/database/postgresql/postgresql-database-plugin/main.go +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/postgresql" ) func main() { - err := postgresql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := postgresql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index bc5b14544..d60ef8bbe 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -6,8 +6,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -35,13 +37,13 @@ func New() (interface{}, error) { } // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*PostgreSQL)) + plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig) return nil } diff --git a/plugins/serve.go b/plugins/serve.go new file mode 100644 index 000000000..263b301f7 --- /dev/null +++ b/plugins/serve.go @@ -0,0 +1,31 @@ +package plugins + +import ( + "fmt" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/pluginutil" +) + +// Serve is used to start a plugin's RPC server. It takes an interface that must +// implement a known plugin interface to vault and an optional api.TLSConfig for +// use during the inital unwrap request to vault. The api config is particulary +// useful when vault is setup to require client cert checking. +func Serve(plugin interface{}, tlsConfig *api.TLSConfig) { + tlsProvider := pluginutil.VaultPluginTLSProvider(tlsConfig) + + err := pluginutil.OptionallyEnableMlock() + if err != nil { + fmt.Println(err) + return + } + + switch p := plugin.(type) { + case dbplugin.Database: + dbplugin.Serve(p, tlsProvider) + default: + fmt.Println("Unsuported plugin type") + } + +} diff --git a/vault/testing.go b/vault/testing.go index b2fe36b33..36bbb1276 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -790,6 +790,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c1.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -799,6 +800,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c2.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -808,6 +810,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c2.redirectAddr = coreConfig.RedirectAddr // // Clustering setup