Update the api for serving plugins and provide a utility to pass TLS data for commuinicating with the vault process
This commit is contained in:
parent
ca7ff89bcb
commit
29d9b831d3
|
@ -4,12 +4,13 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
stdhttp "net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/http"
|
||||
|
@ -77,13 +78,30 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac
|
|||
return
|
||||
}
|
||||
|
||||
func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) {
|
||||
core, _, token, ln := vault.TestCoreUnsealedWithListener(t)
|
||||
http.TestServerWithListener(t, ln, "", core)
|
||||
sys := vault.TestDynamicSystemView(core)
|
||||
vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", "TestBackend_PluginMain")
|
||||
func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"database": Factory,
|
||||
},
|
||||
}
|
||||
|
||||
return core, ln, sys, token
|
||||
handler1 := stdhttp.NewServeMux()
|
||||
handler2 := stdhttp.NewServeMux()
|
||||
handler3 := stdhttp.NewServeMux()
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false)
|
||||
handler1.Handle("/", http.Handler(cores[0].Core))
|
||||
handler2.Handle("/", http.Handler(cores[1].Core))
|
||||
handler3.Handle("/", http.Handler(cores[2].Core))
|
||||
|
||||
core := cores[0]
|
||||
|
||||
sys := vault.TestDynamicSystemView(core.Core)
|
||||
vault.TestAddTestPlugin(t, core.Core, "postgresql-database-plugin", "TestBackend_PluginMain")
|
||||
|
||||
return cores, sys
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain(t *testing.T) {
|
||||
|
@ -91,14 +109,20 @@ func TestBackend_PluginMain(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
postgresql.Run()
|
||||
err := postgresql.Run(&api.TLSConfig{Insecure: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("We shouldn't get here")
|
||||
}
|
||||
|
||||
func TestBackend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
_, ln, sys, _ := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
@ -147,8 +171,10 @@ func TestBackend_config_connection(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
_, ln, sys, _ := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
@ -238,8 +264,10 @@ func TestBackend_basic(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackend_connectionCrud(t *testing.T) {
|
||||
_, ln, sys, _ := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
@ -383,8 +411,10 @@ func TestBackend_connectionCrud(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
_, ln, sys, _ := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
@ -493,8 +523,10 @@ func TestBackend_roleCrud(t *testing.T) {
|
|||
}
|
||||
}
|
||||
func TestBackend_allowedRoles(t *testing.T) {
|
||||
_, ln, sys, _ := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
|
|
@ -2,15 +2,17 @@ package dbplugin_test
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
stdhttp "net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
@ -72,13 +74,26 @@ func (m *mockPlugin) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) {
|
||||
core, _, _, ln := vault.TestCoreUnsealedWithListener(t)
|
||||
http.TestServerWithListener(t, ln, "", core)
|
||||
sys := vault.TestDynamicSystemView(core)
|
||||
vault.TestAddTestPlugin(t, core, "test-plugin", "TestPlugin_Main")
|
||||
func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) {
|
||||
coreConfig := &vault.CoreConfig{}
|
||||
|
||||
return core, ln, sys
|
||||
handler1 := stdhttp.NewServeMux()
|
||||
handler2 := stdhttp.NewServeMux()
|
||||
handler3 := stdhttp.NewServeMux()
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false)
|
||||
handler1.Handle("/", http.Handler(cores[0].Core))
|
||||
handler2.Handle("/", http.Handler(cores[1].Core))
|
||||
handler3.Handle("/", http.Handler(cores[2].Core))
|
||||
|
||||
core := cores[0]
|
||||
|
||||
sys := vault.TestDynamicSystemView(core.Core)
|
||||
vault.TestAddTestPlugin(t, core.Core, "test-plugin", "TestPlugin_Main")
|
||||
|
||||
return cores, sys
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
|
@ -92,12 +107,14 @@ func TestPlugin_Main(t *testing.T) {
|
|||
users: make(map[string][]string),
|
||||
}
|
||||
|
||||
dbplugin.NewPluginServer(plugin)
|
||||
plugins.Serve(plugin, &api.TLSConfig{Insecure: true})
|
||||
}
|
||||
|
||||
func TestPlugin_Initialize(t *testing.T) {
|
||||
_, ln, sys := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
|
@ -120,8 +137,10 @@ func TestPlugin_Initialize(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPlugin_CreateUser(t *testing.T) {
|
||||
_, ln, sys := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
|
@ -155,8 +174,10 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPlugin_RenewUser(t *testing.T) {
|
||||
_, ln, sys := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
|
@ -184,8 +205,10 @@ func TestPlugin_RenewUser(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPlugin_RevokeUser(t *testing.T) {
|
||||
_, ln, sys := getCore(t)
|
||||
defer ln.Close()
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
)
|
||||
|
||||
// Serve is called from within a plugin and wraps the provided
|
||||
// Database implementation in a databasePluginRPCServer object and starts a
|
||||
// RPC server.
|
||||
func Serve(db Database) {
|
||||
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||
dbPlugin := &DatabasePlugin{
|
||||
impl: db,
|
||||
}
|
||||
|
@ -20,16 +19,10 @@ func Serve(db Database) {
|
|||
"database": dbPlugin,
|
||||
}
|
||||
|
||||
err := pluginutil.OptionallyEnableMlock()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
plugin.Serve(&plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
Plugins: pluginMap,
|
||||
TLSProvider: pluginutil.VaultPluginTLSProvider,
|
||||
TLSProvider: tlsProvider,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -2,11 +2,13 @@ package pluginutil
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
)
|
||||
|
||||
|
@ -87,3 +89,43 @@ func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugi
|
|||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type APIClientMeta struct {
|
||||
// These are set by the command line flags.
|
||||
flagCACert string
|
||||
flagCAPath string
|
||||
flagClientCert string
|
||||
flagClientKey string
|
||||
flagInsecure bool
|
||||
}
|
||||
|
||||
func (f *APIClientMeta) FlagSet() *flag.FlagSet {
|
||||
fs := flag.NewFlagSet("tls settings", flag.ContinueOnError)
|
||||
|
||||
fs.StringVar(&f.flagCACert, "ca-cert", "", "")
|
||||
fs.StringVar(&f.flagCAPath, "ca-path", "", "")
|
||||
fs.StringVar(&f.flagClientCert, "client-cert", "", "")
|
||||
fs.StringVar(&f.flagClientKey, "client-key", "", "")
|
||||
fs.BoolVar(&f.flagInsecure, "insecure", false, "")
|
||||
fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "")
|
||||
|
||||
return fs
|
||||
}
|
||||
|
||||
func (f *APIClientMeta) GetTLSConfig() *api.TLSConfig {
|
||||
// If we need custom TLS configuration, then set it
|
||||
if f.flagCACert != "" || f.flagCAPath != "" || f.flagClientCert != "" || f.flagClientKey != "" || f.flagInsecure {
|
||||
t := &api.TLSConfig{
|
||||
CACert: f.flagCACert,
|
||||
CAPath: f.flagCAPath,
|
||||
ClientCert: f.flagClientCert,
|
||||
ClientKey: f.flagClientKey,
|
||||
TLSServerName: "",
|
||||
Insecure: f.flagInsecure,
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -116,109 +116,114 @@ func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (
|
|||
|
||||
// VaultPluginTLSProvider is run inside a plugin and retrives the response
|
||||
// wrapped TLS certificate from vault. It returns a configured TLS Config.
|
||||
func VaultPluginTLSProvider() (*tls.Config, error) {
|
||||
unwrapToken := os.Getenv(PluginUnwrapTokenEnv)
|
||||
func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) {
|
||||
return func() (*tls.Config, error) {
|
||||
unwrapToken := os.Getenv(PluginUnwrapTokenEnv)
|
||||
|
||||
// Ensure unwrap token is a JWT
|
||||
if strings.Count(unwrapToken, ".") != 2 {
|
||||
return nil, errors.New("Could not parse unwraptoken")
|
||||
}
|
||||
// Ensure unwrap token is a JWT
|
||||
if strings.Count(unwrapToken, ".") != 2 {
|
||||
return nil, errors.New("Could not parse unwraptoken")
|
||||
}
|
||||
|
||||
// Parse the JWT and retrieve the vault address
|
||||
wt, err := jws.ParseJWT([]byte(unwrapToken))
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("error decoding token: %s", err))
|
||||
}
|
||||
if wt == nil {
|
||||
return nil, errors.New("nil decoded token")
|
||||
}
|
||||
// Parse the JWT and retrieve the vault address
|
||||
wt, err := jws.ParseJWT([]byte(unwrapToken))
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("error decoding token: %s", err))
|
||||
}
|
||||
if wt == nil {
|
||||
return nil, errors.New("nil decoded token")
|
||||
}
|
||||
|
||||
addrRaw := wt.Claims().Get("addr")
|
||||
if addrRaw == nil {
|
||||
return nil, errors.New("decoded token does not contain primary cluster address")
|
||||
}
|
||||
vaultAddr, ok := addrRaw.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("decoded token's address not valid")
|
||||
}
|
||||
if vaultAddr == "" {
|
||||
return nil, errors.New(`no address for the vault found`)
|
||||
}
|
||||
addrRaw := wt.Claims().Get("addr")
|
||||
if addrRaw == nil {
|
||||
return nil, errors.New("decoded token does not contain primary cluster address")
|
||||
}
|
||||
vaultAddr, ok := addrRaw.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("decoded token's address not valid")
|
||||
}
|
||||
if vaultAddr == "" {
|
||||
return nil, errors.New(`no address for the vault found`)
|
||||
}
|
||||
|
||||
// Sanity check the value
|
||||
if _, err := url.Parse(vaultAddr); err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err))
|
||||
}
|
||||
// Sanity check the value
|
||||
if _, err := url.Parse(vaultAddr); err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err))
|
||||
}
|
||||
|
||||
// Unwrap the token
|
||||
clientConf := api.DefaultConfig()
|
||||
clientConf.Address = vaultAddr
|
||||
client, err := api.NewClient(clientConf)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
|
||||
}
|
||||
// Unwrap the token
|
||||
clientConf := api.DefaultConfig()
|
||||
clientConf.Address = vaultAddr
|
||||
if apiTLSConfig != nil {
|
||||
clientConf.ConfigureTLS(apiTLSConfig)
|
||||
}
|
||||
client, err := api.NewClient(clientConf)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error during api client creation: {{err}}", err)
|
||||
}
|
||||
|
||||
secret, err := client.Logical().Unwrap(unwrapToken)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
|
||||
}
|
||||
if secret == nil {
|
||||
return nil, errors.New("error during token unwrap request secret is nil")
|
||||
}
|
||||
secret, err := client.Logical().Unwrap(unwrapToken)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
|
||||
}
|
||||
if secret == nil {
|
||||
return nil, errors.New("error during token unwrap request secret is nil")
|
||||
}
|
||||
|
||||
// Retrieve and parse the server's certificate
|
||||
serverCertBytesRaw, ok := secret.Data["ServerCert"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("error unmarshalling certificate")
|
||||
// Retrieve and parse the server's certificate
|
||||
serverCertBytesRaw, ok := secret.Data["ServerCert"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("error unmarshalling certificate")
|
||||
}
|
||||
|
||||
serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
serverCert, err := x509.ParseCertificate(serverCertBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and parse the server's private key
|
||||
serverKeyB64, ok := secret.Data["ServerKey"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("error unmarshalling certificate")
|
||||
}
|
||||
|
||||
serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
serverKey, err := x509.ParseECPrivateKey(serverKeyRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
// Add CA cert to the cert pool
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(serverCert)
|
||||
|
||||
// Build a certificate object out of the server's cert and private key.
|
||||
cert := tls.Certificate{
|
||||
Certificate: [][]byte{serverCertBytes},
|
||||
PrivateKey: serverKey,
|
||||
Leaf: serverCert,
|
||||
}
|
||||
|
||||
// Setup TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
ClientCAs: caCertPool,
|
||||
RootCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// TLS 1.2 minimum
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
serverCert, err := x509.ParseCertificate(serverCertBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and parse the server's private key
|
||||
serverKeyB64, ok := secret.Data["ServerKey"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("error unmarshalling certificate")
|
||||
}
|
||||
|
||||
serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
serverKey, err := x509.ParseECPrivateKey(serverKeyRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
// Add CA cert to the cert pool
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(serverCert)
|
||||
|
||||
// Build a certificate object out of the server's cert and private key.
|
||||
cert := tls.Certificate{
|
||||
Certificate: [][]byte{serverCertBytes},
|
||||
PrivateKey: serverKey,
|
||||
Leaf: serverCert,
|
||||
}
|
||||
|
||||
// Setup TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
ClientCAs: caCertPool,
|
||||
RootCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// TLS 1.2 minimum
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
|
|
@ -4,11 +4,16 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/cassandra"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := cassandra.Run()
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := cassandra.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
|
|
|
@ -6,8 +6,10 @@ import (
|
|||
|
||||
"github.com/gocql/gocql"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
|
@ -41,13 +43,13 @@ func New() (interface{}, error) {
|
|||
}
|
||||
|
||||
// Run instantiates a Cassandra object, and runs the RPC server for the plugin
|
||||
func Run() error {
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbplugin.NewPluginServer(dbType.(*Cassandra))
|
||||
plugins.Serve(dbType.(*Cassandra), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,11 +4,16 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/mssql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := mssql.Run()
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := mssql.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
|
|
|
@ -6,8 +6,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
|
@ -39,13 +41,13 @@ func New() (interface{}, error) {
|
|||
}
|
||||
|
||||
// Run instantiates a MSSQL object, and runs the RPC server for the plugin
|
||||
func Run() error {
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbplugin.Serve(dbType.(*MSSQL))
|
||||
plugins.Serve(dbType.(*MSSQL), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,11 +4,16 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/mysql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := mysql.Run()
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := mysql.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
|
|
|
@ -5,8 +5,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
|
@ -42,13 +44,13 @@ func New() (interface{}, error) {
|
|||
}
|
||||
|
||||
// Run instantiates a MySQL object, and runs the RPC server for the plugin
|
||||
func Run() error {
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbplugin.Serve(dbType.(*MySQL))
|
||||
plugins.Serve(dbType.(*MySQL), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,11 +4,16 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := postgresql.Run()
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := postgresql.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
|
|
|
@ -6,8 +6,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
|
@ -35,13 +37,13 @@ func New() (interface{}, error) {
|
|||
}
|
||||
|
||||
// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin
|
||||
func Run() error {
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbplugin.Serve(dbType.(*PostgreSQL))
|
||||
plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
31
plugins/serve.go
Normal file
31
plugins/serve.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package plugins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
)
|
||||
|
||||
// Serve is used to start a plugin's RPC server. It takes an interface that must
|
||||
// implement a known plugin interface to vault and an optional api.TLSConfig for
|
||||
// use during the inital unwrap request to vault. The api config is particulary
|
||||
// useful when vault is setup to require client cert checking.
|
||||
func Serve(plugin interface{}, tlsConfig *api.TLSConfig) {
|
||||
tlsProvider := pluginutil.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
err := pluginutil.OptionallyEnableMlock()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
switch p := plugin.(type) {
|
||||
case dbplugin.Database:
|
||||
dbplugin.Serve(p, tlsProvider)
|
||||
default:
|
||||
fmt.Println("Unsuported plugin type")
|
||||
}
|
||||
|
||||
}
|
|
@ -790,6 +790,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
c1.redirectAddr = coreConfig.RedirectAddr
|
||||
|
||||
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port)
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
|
@ -799,6 +800,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
c2.redirectAddr = coreConfig.RedirectAddr
|
||||
|
||||
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
|
@ -808,6 +810,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
c2.redirectAddr = coreConfig.RedirectAddr
|
||||
|
||||
//
|
||||
// Clustering setup
|
||||
|
|
Loading…
Reference in a new issue