From acdcd79af3e7ab392be133ee36036854e283251b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 19 Dec 2016 10:15:58 -0800 Subject: [PATCH 001/162] Begin work on database refactor --- builtin/logical/database/backend.go | 104 +++ builtin/logical/database/backend_test.go | 620 ++++++++++++++++++ builtin/logical/database/dbs/cassandra.go | 194 ++++++ builtin/logical/database/dbs/db.go | 56 ++ builtin/logical/database/dbs/postgresql.go | 336 ++++++++++ .../database/path_config_connection.go | 188 ++++++ builtin/logical/database/path_config_lease.go | 103 +++ builtin/logical/database/path_role_create.go | 120 ++++ builtin/logical/database/path_roles.go | 161 +++++ builtin/logical/database/secret_creds.go | 147 +++++ cli/commands.go | 2 + 11 files changed, 2031 insertions(+) create mode 100644 builtin/logical/database/backend.go create mode 100644 builtin/logical/database/backend_test.go create mode 100644 builtin/logical/database/dbs/cassandra.go create mode 100644 builtin/logical/database/dbs/db.go create mode 100644 builtin/logical/database/dbs/postgresql.go create mode 100644 builtin/logical/database/path_config_connection.go create mode 100644 builtin/logical/database/path_config_lease.go create mode 100644 builtin/logical/database/path_role_create.go create mode 100644 builtin/logical/database/path_roles.go create mode 100644 builtin/logical/database/secret_creds.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go new file mode 100644 index 000000000..8b7fa3670 --- /dev/null +++ b/builtin/logical/database/backend.go @@ -0,0 +1,104 @@ +package database + +import ( + "strings" + "sync" + + log "github.com/mgutz/logxi/v1" + + "github.com/hashicorp/vault/builtin/logical/database/dbs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func Factory(conf *logical.BackendConfig) (logical.Backend, error) { + return Backend(conf).Setup(conf) +} + +func Backend(conf *logical.BackendConfig) *databaseBackend { + var b databaseBackend + b.Backend = &framework.Backend{ + Help: strings.TrimSpace(backendHelp), + + Paths: []*framework.Path{ + pathConfigConnection(&b), + pathConfigLease(&b), + pathListRoles(&b), + pathRoles(&b), + pathRoleCreate(&b), + }, + + Secrets: []*framework.Secret{ + secretCreds(&b), + }, + + Clean: b.resetAllDBs, + } + + b.logger = conf.Logger + b.connections = make(map[string]dbs.DatabaseType) + return &b +} + +type databaseBackend struct { + connections map[string]dbs.DatabaseType + logger log.Logger + + *framework.Backend + sync.RWMutex +} + +// resetAllDBs closes all connections from all database types +func (b *databaseBackend) resetAllDBs() { + b.logger.Trace("postgres/resetdb: enter") + defer b.logger.Trace("postgres/resetdb: exit") + + b.Lock() + defer b.Unlock() + + for _, db := range b.connections { + db.Close() + } +} + +// Lease returns the lease information +func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) { + entry, err := s.Get("config/lease") + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result configLease + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { + entry, err := s.Get("role/" + n) + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result roleEntry + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +const backendHelp = ` +The PostgreSQL backend dynamically generates database users. + +After mounting this backend, configure it using the endpoints within +the "config/" path. +` diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go new file mode 100644 index 000000000..a203c9b19 --- /dev/null +++ b/builtin/logical/database/backend_test.go @@ -0,0 +1,620 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "os" + "path" + "reflect" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/logical" + logicaltest "github.com/hashicorp/vault/logical/testing" + "github.com/lib/pq" + "github.com/mitchellh/mapstructure" + "github.com/ory-am/dockertest" +) + +var ( + testImagePull sync.Once +) + +func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) { + if os.Getenv("PG_URL") != "" { + return "", os.Getenv("PG_URL") + } + + // Without this the checks for whether the container has started seem to + // never actually pass. There's really no reason to expose the test + // containers, so don't. + dockertest.BindDockerToLocalhost = "yep" + + testImagePull.Do(func() { + dockertest.Pull("postgres") + }) + + cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { + // This will cause a validation to run + resp, err := b.HandleRequest(&logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: "config/connection", + Data: map[string]interface{}{ + "connection_url": connURL, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + // It's likely not up and running yet, so return false and try again + return false + } + if resp == nil { + t.Fatal("expected warning") + } + + retURL = connURL + return true + }) + + if connErr != nil { + t.Fatalf("could not connect to database: %v", connErr) + } + + return +} + +func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { + err := cid.KillRemove() + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "value": "", + "max_open_connections": 9, + "max_idle_connections": 7, + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(configData, "verify_connection") + if !reflect.DeepEqual(configData, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) + } +} + +func TestBackend_basic(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepReadCreds(t, b, config.StorageView, "web", connURL), + }, + }) +} + +func TestBackend_roleCrud(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepReadRole(t, "web", testRole), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web", ""), + }, + }) +} + +func TestBackend_BlockStatements(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice) + if err != nil { + t.Fatal(err) + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + // This will also validate the query + testAccStepCreateRole(t, "web-block", testBlockStatementRole, true), + testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false), + }, + }) +} + +func TestBackend_roleReadOnly(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false), + testAccStepReadRole(t, "web-readonly", testReadOnlyRole), + testAccStepCreateTable(t, b, config.StorageView, "web", connURL), + testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), + testAccStepDropTable(t, b, config.StorageView, "web", connURL), + testAccStepDeleteRole(t, "web-readonly"), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web-readonly", ""), + }, + }) +} + +func TestBackend_roleReadOnly_revocationSQL(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false), + testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false), + testAccStepReadRole(t, "web-readonly", testReadOnlyRole), + testAccStepCreateTable(t, b, config.StorageView, "web", connURL), + testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), + testAccStepDropTable(t, b, config.StorageView, "web", connURL), + testAccStepDeleteRole(t, "web-readonly"), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web-readonly", ""), + }, + }) +} + +func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Data: d, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if expectError { + if resp.Data == nil { + return fmt.Errorf("data is nil") + } + var e struct { + Error string `mapstructure:"error"` + } + if err := mapstructure.Decode(resp.Data, &e); err != nil { + return err + } + if len(e.Error) == 0 { + return fmt.Errorf("expected error, but write succeeded.") + } + return nil + } else if resp != nil && resp.IsError() { + return fmt.Errorf("got an error response: %v", resp.Error()) + } + return nil + }, + } +} + +func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: path.Join("roles", name), + Data: map[string]interface{}{ + "sql": sql, + }, + ErrorOk: expectFail, + } +} + +func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: path.Join("roles", name), + Data: map[string]interface{}{ + "sql": sql, + "revocation_sql": revocationSQL, + }, + ErrorOk: expectFail, + } +} + +func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.DeleteOperation, + Path: path.Join("roles", name), + } +} + +func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + returnedRows := func() int { + stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") + if err != nil { + return -1 + } + defer stmt.Close() + + rows, err := stmt.Query(d.Username) + if err != nil { + return -1 + } + defer rows.Close() + + i := 0 + for rows.Next() { + i++ + } + return i + } + + // minNumPermissions is the minimum number of permissions that will always be present. + const minNumPermissions = 2 + + userRows := returnedRows() + if userRows < minNumPermissions { + t.Fatalf("did not get expected number of rows, got %d", userRows) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + "role": name, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + userRows = returnedRows() + // User shouldn't exist so returnedRows() should encounter an error and exit with -1 + if userRows != -1 { + t.Fatalf("did not get expected number of rows, got %d", userRows) + } + + return nil + }, + } +} + +func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);") + if err != nil { + t.Fatal(err) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + return nil + }, + } +} + +func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("DROP TABLE test;") + if err != nil { + t.Fatal(err) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + return nil + }, + } +} + +func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "roles/" + name, + Check: func(resp *logical.Response) error { + if resp == nil { + if sql == "" { + return nil + } + + return fmt.Errorf("bad: %#v", resp) + } + + var d struct { + SQL string `mapstructure:"sql"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + if d.SQL != sql { + return fmt.Errorf("bad: %#v", resp) + } + + return nil + }, + } +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; +REVOKE USAGE ON SCHEMA public FROM {{name}}; + +DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go new file mode 100644 index 000000000..8c7a068be --- /dev/null +++ b/builtin/logical/database/dbs/cassandra.go @@ -0,0 +1,194 @@ +package dbs + +import ( + "crypto/tls" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" +) + +type Cassandra struct { + // Session is goroutine safe, however, since we reinitialize + // it when connection info changes, we want to make sure we + // can close it and use a new connection; hence the lock + session *gocql.Session + config ConnectionConfig + + sync.RWMutex +} + +func (c *Cassandra) Type() string { + return cassandraTypeName +} + +func (c *Cassandra) Connection() (*gocql.Session, error) { + // Grab the write lock + c.Lock() + defer c.Unlock() + + // If we already have a DB, we got it! + if c.session != nil { + return c.session, nil + } + + session, err := createSession(c.config) + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + c.session = session + + return session, nil +} + +func (p *Cassandra) Close() { + // Grab the write lock + p.Lock() + defer p.Unlock() + + if p.session != nil { + p.session.Close() + } + + p.session = nil +} + +func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) { + // Grab the write lock + p.Lock() + p.config = config + p.Unlock() + + p.Close() + return p.Connection() +} + +func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error { + // Get the connection + db, err := p.Connection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) RenewUser(username, expiration string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: This is Racey + // Grab the read lock + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: this is Racey + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) DefaultRevokeUser(username string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + db, err := p.Connection() + + return nil +} + +func createSession(cfg *ConnectionConfig) (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: cfg.Username, + Password: cfg.Password, + } + + clusterConfig.ProtoVersion = cfg.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + + if cfg.TLS { + var tlsConfig *tls.Config + if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { + if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(cfg.Certificate) > 0 { + certBundle.Certificate = cfg.Certificate + certBundle.PrivateKey = cfg.PrivateKey + } + if len(cfg.IssuingCA) > 0 { + certBundle.IssuingCA = cfg.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + + if cfg.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("Error creating session: %s", err) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("Error validating connection info: %s", err) + } + + return session, nil +} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go new file mode 100644 index 000000000..ee7b15b64 --- /dev/null +++ b/builtin/logical/database/dbs/db.go @@ -0,0 +1,56 @@ +package dbs + +import ( + "database/sql" + "errors" + "fmt" + "strings" +) + +const ( + postgreSQLTypeName = "postgres" + cassandraTypeName = "cassandra" +) + +var ( + ErrUnsupportedDatabaseType = errors.New("Unsupported database type") +) + +func Factory(conf ConnectionConfig) (DatabaseType, error) { + switch conf.ConnectionType { + case postgreSQLTypeName: + return &PostgreSQL{ + config: conf, + }, nil + } + + return nil, ErrUnsupportedDatabaseType +} + +type DatabaseType interface { + Type() string + Connection() (*sql.DB, error) + Close() + Reset(ConnectionConfig) (*sql.DB, error) + CreateUser(createStmt, username, password, expiration string) error + RenewUser(username, expiration string) error + CustomRevokeUser(username, revocationSQL string) error + DefaultRevokeUser(username string) error +} + +type ConnectionConfig struct { + ConnectionType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` +} + +// Query templates a query for us. +func queryHelper(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +} diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go new file mode 100644 index 000000000..ea7d08f8a --- /dev/null +++ b/builtin/logical/database/dbs/postgresql.go @@ -0,0 +1,336 @@ +package dbs + +import ( + "database/sql" + "fmt" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/strutil" + "github.com/lib/pq" +) + +type PostgreSQL struct { + db *sql.DB + config ConnectionConfig + + sync.RWMutex +} + +func (p *PostgreSQL) Type() string { + return postgreSQLTypeName +} + +func (p *PostgreSQL) Connection() (*sql.DB, error) { + // Grab the write lock + p.Lock() + defer p.Unlock() + + // If we already have a DB, we got it! + if p.db != nil { + if err := p.db.Ping(); err == nil { + return p.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + p.db.Close() + } + + // Otherwise, attempt to make connection + conn := p.config.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } else { + conn += " timezone=utc" + } + + var err error + p.db, err = sql.Open("postgres", conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + p.db.SetMaxOpenConns(p.config.MaxOpenConnections) + p.db.SetMaxIdleConns(p.config.MaxIdleConnections) + + return p.db, nil +} + +func (p *PostgreSQL) Close() { + // Grab the write lock + p.Lock() + defer p.Unlock() + + if p.db != nil { + p.db.Close() + } + + p.db = nil +} + +func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) { + // Grab the write lock + p.Lock() + p.config = config + p.Unlock() + + p.Close() + return p.Connection() +} + +func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error { + // Get the connection + db, err := p.Connection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + // Start a transaction + // b.logger.Trace("postgres/pathRoleCreateRead: starting transaction") + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + // b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction") + tx.Rollback() + }() + // Return the secret + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // b.logger.Trace("postgres/pathRoleCreateRead: preparing statement") + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expiration, + })) + if err != nil { + return err + } + defer stmt.Close() + // b.logger.Trace("postgres/pathRoleCreateRead: executing statement") + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + + // b.logger.Trace("postgres/pathRoleCreateRead: committing transaction") + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) RenewUser(username, expiration string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: This is Racey + // Grab the read lock + p.RLock() + defer p.RUnlock() + + query := fmt.Sprintf( + "ALTER ROLE %s VALID UNTIL '%s';", + pq.QuoteIdentifier(username), + expiration) + + stmt, err := db.Prepare(query) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: this is Racey + p.RLock() + defer p.RUnlock() + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + tx.Rollback() + }() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + + if _, err := stmt.Exec(); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) DefaultRevokeUser(username string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + db, err := p.Connection() + if err != nil { + return err + } + + // Check if the role exists + var exists bool + err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + + if exists == false { + return nil + } + + // Query for permissions; we need to revoke permissions before we can drop + // the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query(username) + if err != nil { + return err + } + defer rows.Close() + + const initialNumRevocations = 16 + revocationStmts := make([]string, 0, initialNumRevocations) + for rows.Next() { + var schema string + err = rows.Scan(&schema) + if err != nil { + // keep going; remove as many permissions as possible right now + continue + } + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE USAGE ON SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + } + + // for good measure, revoke all privileges and usage on schema public + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE USAGE ON SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + // get the current database name so we can issue a REVOKE CONNECT for + // this username + var dbname sql.NullString + if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { + return err + } + + if dbname.Valid { + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE CONNECT ON DATABASE %s FROM %s;`, + pq.QuoteIdentifier(dbname.String), + pq.QuoteIdentifier(username))) + } + + // again, here, we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revocationStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all privileges are revoked + if rows.Err() != nil { + return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) + } + + // Drop this user + stmt, err = db.Prepare(fmt.Sprintf( + `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go new file mode 100644 index 000000000..be017ea35 --- /dev/null +++ b/builtin/logical/database/path_config_connection.go @@ -0,0 +1,188 @@ +package database + +import ( + "fmt" + + "github.com/fatih/structs" + "github.com/hashicorp/vault/builtin/logical/database/dbs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathConfigConnection(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of this DB type", + }, + + "connection_type": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "DB type (e.g. postgres)", + }, + + "connection_url": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "DB connection string", + }, + + "connection_details": &framework.FieldSchema{ + Type: framework.TypeMap, + Description: "Connection details for specified connection type.", + }, + + "verify_connection": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: true, + Description: `If set, connection_url is verified by actually connecting to the database`, + }, + + "max_open_connections": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum number of open connections to the database; +a zero uses the default value of two and a +negative value means unlimited`, + }, + + "max_idle_connections": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum number of idle connections to the database; +a zero uses the value of max_open_connections +and a negative value disables idle connections. +If larger than max_open_connections it will be +reduced to the same size.`, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathConnectionWrite, + logical.ReadOperation: b.pathConnectionRead, + }, + + HelpSynopsis: pathConfigConnectionHelpSyn, + HelpDescription: pathConfigConnectionHelpDesc, + } +} + +// pathConnectionRead reads out the connection configuration +func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config dbs.ConnectionConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil +} + +func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + connURL := data.Get("connection_url").(string) + connType := data.Get("connection_type").(string) + + maxOpenConns := data.Get("max_open_connections").(int) + if maxOpenConns == 0 { + maxOpenConns = 2 + } + + maxIdleConns := data.Get("max_idle_connections").(int) + if maxIdleConns == 0 { + maxIdleConns = maxOpenConns + } + if maxIdleConns > maxOpenConns { + maxIdleConns = maxOpenConns + } + + config := dbs.ConnectionConfig{ + ConnectionType: connType, + ConnectionURL: connURL, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + } + + name := data.Get("name").(string) + + // Grab the mutex lock + b.Lock() + defer b.Unlock() + + var err error + var db dbs.DatabaseType + if _, ok := b.connections[name]; ok { + + // Don't allow the connection type to change + if b.connections[name].Type() != connType { + return logical.ErrorResponse("can not change type of existing connection"), nil + } + + db = b.connections[name] + } else { + db, err = dbs.Factory(config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + } + + /* + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + } + */ + + // Store it + entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + // Reset the DB connection + db.Reset(config) + b.connections[name] = db + + resp := &logical.Response{} + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + + return resp, nil +} + +const pathConfigConnectionHelpSyn = ` +Configure the connection string to talk to PostgreSQL. +` + +const pathConfigConnectionHelpDesc = ` +This path configures the connection string used to connect to PostgreSQL. +The value of the string can be a URL, or a PG style string in the +format of "user=foo host=bar" etc. + +The URL looks like: +"postgresql://user:pass@host:port/dbname" + +When configuring the connection string, the backend will verify its validity. +` diff --git a/builtin/logical/database/path_config_lease.go b/builtin/logical/database/path_config_lease.go new file mode 100644 index 000000000..5cc40a056 --- /dev/null +++ b/builtin/logical/database/path_config_lease.go @@ -0,0 +1,103 @@ +package database + +import ( + "fmt" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfigLease(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "config/lease", + Fields: map[string]*framework.FieldSchema{ + "lease": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Default lease for roles.", + }, + + "lease_max": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Maximum time a credential is valid for.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathLeaseRead, + logical.UpdateOperation: b.pathLeaseWrite, + }, + + HelpSynopsis: pathConfigLeaseHelpSyn, + HelpDescription: pathConfigLeaseHelpDesc, + } +} + +func (b *databaseBackend) pathLeaseWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + leaseRaw := d.Get("lease").(string) + leaseMaxRaw := d.Get("lease_max").(string) + + lease, err := time.ParseDuration(leaseRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + leaseMax, err := time.ParseDuration(leaseMaxRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + + // Store it + entry, err := logical.StorageEntryJSON("config/lease", &configLease{ + Lease: lease, + LeaseMax: leaseMax, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +func (b *databaseBackend) pathLeaseRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + lease, err := b.Lease(req.Storage) + + if err != nil { + return nil, err + } + if lease == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "lease": lease.Lease.String(), + "lease_max": lease.LeaseMax.String(), + }, + }, nil +} + +type configLease struct { + Lease time.Duration + LeaseMax time.Duration +} + +const pathConfigLeaseHelpSyn = ` +Configure the default lease information for generated credentials. +` + +const pathConfigLeaseHelpDesc = ` +This configures the default lease information used for credentials +generated by this backend. The lease specifies the duration that a +credential will be valid for, as well as the maximum session for +a set of credentials. + +The format for the lease is "1h" or integer and then unit. The longest +unit is hour. +` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go new file mode 100644 index 000000000..2a2386d01 --- /dev/null +++ b/builtin/logical/database/path_role_create.go @@ -0,0 +1,120 @@ +package database + +import ( + "fmt" + "time" + + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathRoleCreate(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "creds/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the role.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleCreateRead, + }, + + HelpSynopsis: pathRoleCreateReadHelpSyn, + HelpDescription: pathRoleCreateReadHelpDesc, + } +} + +func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.logger.Trace("postgres/pathRoleCreateRead: enter") + defer b.logger.Trace("postgres/pathRoleCreateRead: exit") + + name := data.Get("name").(string) + + // Get the role + b.logger.Trace("postgres/pathRoleCreateRead: getting role") + role, err := b.Role(req.Storage, name) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil + } + + // Determine if we have a lease + b.logger.Trace("postgres/pathRoleCreateRead: getting lease") + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + // Unlike some other backends we need a lease here (can't leave as 0 and + // let core fill it in) because Postgres also expires users as a safety + // measure, so cannot be zero + if lease == nil { + lease = &configLease{ + Lease: b.System().DefaultLeaseTTL(), + } + } + + // Generate the username, password and expiration. PG limits user to 63 characters + displayName := req.DisplayName + if len(displayName) > 26 { + displayName = displayName[:26] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if len(username) > 63 { + username = username[:63] + } + password, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + expiration := time.Now(). + Add(lease.Lease). + Format("2006-01-02 15:04:05-0700") + + // Get our handle + b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") + + b.RLock() + defer b.RUnlock() + db, ok := b.connections[role.DBName] + if !ok { + // TODO: return a resp error instead? + return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) + } + + err = db.CreateUser(role.CreationStatement, username, password, expiration) + if err != nil { + return nil, err + } + + b.logger.Trace("postgres/pathRoleCreateRead: generating secret") + resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ + "username": username, + "password": password, + }, map[string]interface{}{ + "username": username, + "role": name, + }) + resp.Secret.TTL = lease.Lease + return resp, nil +} + +const pathRoleCreateReadHelpSyn = ` +Request database credentials for a certain role. +` + +const pathRoleCreateReadHelpDesc = ` +This path reads database credentials for a certain role. The +database credentials will be generated on demand and will be automatically +revoked when the lease is up. +` diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go new file mode 100644 index 000000000..e06518b28 --- /dev/null +++ b/builtin/logical/database/path_roles.go @@ -0,0 +1,161 @@ +package database + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathListRoles(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "roles/?$", + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: b.pathRoleList, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func pathRoles(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "roles/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Description: "Name of the role.", + }, + + "db_name": { + Type: framework.TypeString, + Description: "Name of the database this role acts on.", + }, + + "creation_statement": { + Type: framework.TypeString, + Description: "SQL string to create a user. See help for more info.", + }, + + "revocation_statement": { + Type: framework.TypeString, + Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleRead, + logical.UpdateOperation: b.pathRoleCreate, + logical.DeleteOperation: b.pathRoleDelete, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func (b *databaseBackend) pathRoleDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + err := req.Storage.Delete("role/" + data.Get("name").(string)) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + role, err := b.Role(req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "creation_statment": role.CreationStatement, + "revocation_statement": role.RevocationStatement, + }, + }, nil +} + +func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + entries, err := req.Storage.List("role/") + if err != nil { + return nil, err + } + + return logical.ListResponse(entries), nil +} + +func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + dbName := data.Get("db_name").(string) + creationStmt := data.Get("creation_statement").(string) + revocationStmt := data.Get("revocation_statement").(string) + + // TODO: Think about preparing the statments to test. + + // Store it + entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ + DBName: dbName, + CreationStatement: creationStmt, + RevocationStatement: revocationStmt, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +type roleEntry struct { + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` +} + +const pathRoleHelpSyn = ` +Manage the roles that can be created with this backend. +` + +const pathRoleHelpDesc = ` +This path lets you manage the roles that can be created with this backend. + +The "sql" parameter customizes the SQL string used to create the role. +This can be a sequence of SQL queries. Some substitution will be done to the +SQL string for certain keys. The names of the variables must be surrounded +by "{{" and "}}" to be replaced. + + * "name" - The random username generated for the DB user. + + * "password" - The random password generated for the DB user. + + * "expiration" - The timestamp when this user will expire. + +Example of a decent SQL query to use: + + CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; + +Note the above user would be able to access everything in schema public. +For more complex GRANT clauses, see the PostgreSQL manual. + +The "revocation_sql" parameter customizes the SQL string used to revoke a user. +Example of a decent revocation SQL query to use: + + REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; + REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; + REVOKE USAGE ON SCHEMA public FROM {{name}}; + DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go new file mode 100644 index 000000000..30c4a6430 --- /dev/null +++ b/builtin/logical/database/secret_creds.go @@ -0,0 +1,147 @@ +package database + +import ( + "errors" + "fmt" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +const SecretCredsType = "creds" + +func secretCreds(b *databaseBackend) *framework.Secret { + return &framework.Secret{ + Type: SecretCredsType, + Fields: map[string]*framework.FieldSchema{ + "username": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Username", + }, + + "password": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Password", + }, + }, + + Renew: b.secretCredsRenew, + Revoke: b.secretCredsRevoke, + } +} + +func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + dbName := d.Get("name").(string) + + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + // Get our connection + db, ok := b.connections[dbName] + if !ok { + return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName)) + } + + // Get the lease information + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + if lease == nil { + lease = &configLease{} + } + + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System()) + resp, err := f(req, d) + if err != nil { + return nil, err + } + + // Make sure we increase the VALID UNTIL endpoint for this user. + if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { + expiration := expireTime.Format("2006-01-02 15:04:05-0700") + + err := db.RenewUser(username, expiration) + if err != nil { + return nil, err + } + } + + return resp, nil +} + +func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + var revocationSQL string + var resp *logical.Response + + roleNameRaw, ok := req.Secret.InternalData["role"] + if !ok { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + } + + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + } + + /* TODO: think about how to handle this case. + if !ok { + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + if resp == nil { + resp = &logical.Response{} + } + resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) + } else { + revocationSQL = role.RevocationStatement + } + }*/ + + // Grab the read lock + b.RLock() + defer b.RUnlock() + + // Get our connection + db, ok := b.connections[role.DBName] + if !ok { + return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) + } + + // TODO: Maybe move this down into db package? + switch revocationSQL { + + // This is the default revocation logic. If revocation SQL is provided it + // is simply executed as-is. + case "": + err := db.DefaultRevokeUser(username) + if err != nil { + return nil, err + } + + // We have revocation SQL, execute directly, within a transaction + default: + err := db.CustomRevokeUser(username, revocationSQL) + if err != nil { + return nil, err + } + } + + return resp, nil +} diff --git a/cli/commands.go b/cli/commands.go index 190111177..13f7c8b25 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/vault/builtin/logical/aws" "github.com/hashicorp/vault/builtin/logical/cassandra" "github.com/hashicorp/vault/builtin/logical/consul" + "github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/mongodb" "github.com/hashicorp/vault/builtin/logical/mssql" "github.com/hashicorp/vault/builtin/logical/mysql" @@ -91,6 +92,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { "mysql": mysql.Factory, "ssh": ssh.Factory, "rabbitmq": rabbitmq.Factory, + "database": database.Factory, }, ShutdownCh: command.MakeShutdownCh(), SighupCh: command.MakeSighupCh(), From 2ec5ab56160fbf689cbab72d201dc001123f40c2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 20 Dec 2016 11:46:20 -0800 Subject: [PATCH 002/162] More work on refactor and cassandra database --- builtin/logical/database/backend.go | 19 -- builtin/logical/database/dbs/cassandra.go | 204 ++++---------- .../database/dbs/connectionproducer.go | 254 ++++++++++++++++++ .../database/dbs/credentialsproducer.go | 79 ++++++ builtin/logical/database/dbs/db.go | 67 +++-- builtin/logical/database/dbs/postgresql.go | 102 ++----- .../database/path_config_connection.go | 10 +- builtin/logical/database/path_config_lease.go | 103 ------- builtin/logical/database/path_role_create.go | 52 +--- builtin/logical/database/path_roles.go | 50 +++- builtin/logical/database/secret_creds.go | 47 ++-- 11 files changed, 553 insertions(+), 434 deletions(-) create mode 100644 builtin/logical/database/dbs/connectionproducer.go create mode 100644 builtin/logical/database/dbs/credentialsproducer.go delete mode 100644 builtin/logical/database/path_config_lease.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 8b7fa3670..3d757df1d 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -22,7 +22,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { Paths: []*framework.Path{ pathConfigConnection(&b), - pathConfigLease(&b), pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), @@ -61,24 +60,6 @@ func (b *databaseBackend) resetAllDBs() { } } -// Lease returns the lease information -func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) { - entry, err := s.Get("config/lease") - if err != nil { - return nil, err - } - if entry == nil { - return nil, nil - } - - var result configLease - if err := entry.DecodeJSON(&result); err != nil { - return nil, err - } - - return &result, nil -} - func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 8c7a068be..a8889032f 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -1,25 +1,20 @@ package dbs import ( - "crypto/tls" - "database/sql" "fmt" "strings" "sync" - "time" "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/certutil" - "github.com/hashicorp/vault/helper/tlsutil" + "github.com/hashicorp/vault/helper/strutil" ) type Cassandra struct { // Session is goroutine safe, however, since we reinitialize // it when connection info changes, we want to make sure we // can close it and use a new connection; hence the lock - session *gocql.Session - config ConnectionConfig - + ConnectionProducer + CredentialsProducer sync.RWMutex } @@ -27,168 +22,85 @@ func (c *Cassandra) Type() string { return cassandraTypeName } -func (c *Cassandra) Connection() (*gocql.Session, error) { - // Grab the write lock - c.Lock() - defer c.Unlock() - - // If we already have a DB, we got it! - if c.session != nil { - return c.session, nil - } - - session, err := createSession(c.config) +func (c *Cassandra) getConnection() (*gocql.Session, error) { + session, err := c.Connection() if err != nil { return nil, err } - // Store the session in backend for reuse - c.session = session - - return session, nil + return session.(*gocql.Session), nil } -func (p *Cassandra) Close() { - // Grab the write lock - p.Lock() - defer p.Unlock() - - if p.session != nil { - p.session.Close() - } - - p.session = nil -} - -func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) { - // Grab the write lock - p.Lock() - p.config = config - p.Unlock() - - p.Close() - return p.Connection() -} - -func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error { +func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { // Get the connection - db, err := p.Connection() + session, err := c.getConnection() if err != nil { return err } // TODO: This is racey // Grab a read lock - p.RLock() - defer p.RUnlock() + c.RLock() + defer c.RUnlock() - return nil -} + // Set consistency + /* if .Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency) + if err != nil { + return err + } -func (p *Cassandra) RenewUser(username, expiration string) error { - db, err := p.Connection() - if err != nil { - return err + session.SetConsistency(consistencyValue) + }*/ + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err = session.Query(queryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + if err != nil { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + session.Query(queryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + } + return err + } } - // TODO: This is Racey - // Grab the read lock - p.RLock() - defer p.RUnlock() return nil } -func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error { - db, err := p.Connection() +func (c *Cassandra) RenewUser(username, expiration string) error { + // NOOP + return nil +} + +func (c *Cassandra) RevokeUser(username, revocationSQL string) error { + session, err := c.getConnection() if err != nil { return err } // TODO: this is Racey - p.RLock() - defer p.RUnlock() + c.RLock() + defer c.RUnlock() + + err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() + if err != nil { + return fmt.Errorf("error removing user %s", username) + } return nil } - -func (p *Cassandra) DefaultRevokeUser(username string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() - - db, err := p.Connection() - - return nil -} - -func createSession(cfg *ConnectionConfig) (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) - clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: cfg.Username, - Password: cfg.Password, - } - - clusterConfig.ProtoVersion = cfg.ProtocolVersion - if clusterConfig.ProtoVersion == 0 { - clusterConfig.ProtoVersion = 2 - } - - clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second - - if cfg.TLS { - var tlsConfig *tls.Config - if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { - if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") - } - - certBundle := &certutil.CertBundle{} - if len(cfg.Certificate) > 0 { - certBundle.Certificate = cfg.Certificate - certBundle.PrivateKey = cfg.PrivateKey - } - if len(cfg.IssuingCA) > 0 { - certBundle.IssuingCA = cfg.IssuingCA - } - - parsedCertBundle, err := certBundle.ToParsedCertBundle() - if err != nil { - return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) - } - - tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) - if err != nil || tlsConfig == nil { - return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) - } - tlsConfig.InsecureSkipVerify = cfg.InsecureTLS - - if cfg.TLSMinVersion != "" { - var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("invalid 'tls_min_version' in config") - } - } else { - // MinVersion was not being set earlier. Reset it to - // zero to gracefully handle upgrades. - tlsConfig.MinVersion = 0 - } - } - - clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, - } - } - - session, err := clusterConfig.CreateSession() - if err != nil { - return nil, fmt.Errorf("Error creating session: %s", err) - } - - // Verify the info - err = session.Query(`LIST USERS`).Exec() - if err != nil { - return nil, fmt.Errorf("Error validating connection info: %s", err) - } - - return session, nil -} diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go new file mode 100644 index 000000000..adecfd55a --- /dev/null +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -0,0 +1,254 @@ +package dbs + +import ( + "crypto/tls" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" + "github.com/mitchellh/mapstructure" +) + +type ConnectionProducer interface { + Connection() (interface{}, error) + Close() + // TODO: Should we make this immutable instead? + Reset(*DatabaseConfig) error +} + +// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases +type sqlConnectionDetails struct { + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` +} + +type sqlConnectionProducer struct { + config *DatabaseConfig + // TODO: Should we merge these two structures make it immutable? + connDetails *sqlConnectionDetails + + db *sql.DB + sync.Mutex +} + +func (cp *sqlConnectionProducer) Connection() (interface{}, error) { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + // If we already have a DB, we got it! + if cp.db != nil { + if err := cp.db.Ping(); err == nil { + return cp.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + cp.db.Close() + } + + // Otherwise, attempt to make connection + conn := cp.connDetails.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } else { + conn += " timezone=utc" + } + + var err error + cp.db, err = sql.Open(cp.config.DatabaseType, conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) + cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) + + return cp.db, nil +} + +func (cp *sqlConnectionProducer) Close() { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + if cp.db != nil { + cp.db.Close() + } + + cp.db = nil +} + +func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error { + // Grab the write lock + cp.Lock() + + var details *sqlConnectionDetails + err := mapstructure.Decode(config.ConnectionDetails, &details) + if err != nil { + return err + } + + cp.connDetails = details + cp.config = config + + cp.Unlock() + + cp.Close() + _, err = cp.Connection() + return err +} + +// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra +type cassandraConnectionDetails struct { + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` +} + +type cassandraConnectionProducer struct { + config *DatabaseConfig + // TODO: Should we merge these two structures make it immutable? + connDetails *cassandraConnectionDetails + + session *gocql.Session + sync.Mutex +} + +func (cp *cassandraConnectionProducer) Connection() (interface{}, error) { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + // If we already have a DB, we got it! + if cp.session != nil { + return cp.session, nil + } + + session, err := cp.createSession(cp.connDetails) + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + cp.session = session + + return session, nil +} + +func (cp *cassandraConnectionProducer) Close() { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + if cp.session != nil { + cp.session.Close() + } + + cp.session = nil +} + +func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error { + // Grab the write lock + cp.Lock() + cp.config = config + cp.Unlock() + + cp.Close() + _, err := cp.Connection() + + return err +} + +func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: cfg.Username, + Password: cfg.Password, + } + + clusterConfig.ProtoVersion = cfg.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + + if cfg.TLS { + var tlsConfig *tls.Config + if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { + if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(cfg.Certificate) > 0 { + certBundle.Certificate = cfg.Certificate + certBundle.PrivateKey = cfg.PrivateKey + } + if len(cfg.IssuingCA) > 0 { + certBundle.IssuingCA = cfg.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + + if cfg.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("Error creating session: %s", err) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("Error validating connection info: %s", err) + } + + return session, nil +} diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go new file mode 100644 index 000000000..20210c2e8 --- /dev/null +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -0,0 +1,79 @@ +package dbs + +import ( + "fmt" + "strings" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +type CredentialsProducer interface { + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) string +} + +// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types. +type sqlCredentialsProducer struct { + displayNameLen int + usernameLen int +} + +func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { + // Generate the username, password and expiration. PG limits user to 63 characters + if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen { + displayName = displayName[:scg.displayNameLen] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if scg.usernameLen > 0 && len(username) > scg.usernameLen { + username = username[:scg.usernameLen] + } + + return username, nil +} + +func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { + return time.Now(). + Add(ttl). + Format("2006-01-02 15:04:05-0700") +} + +type cassandraCredentialsProducer struct{} + +func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) + username = strings.Replace(username, "-", "_", -1) + + return username, nil +} + +func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string { + return "" +} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index ee7b15b64..9d261ff42 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -1,10 +1,11 @@ package dbs import ( - "database/sql" "errors" "fmt" "strings" + + "github.com/mitchellh/mapstructure" ) const ( @@ -16,11 +17,47 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -func Factory(conf ConnectionConfig) (DatabaseType, error) { - switch conf.ConnectionType { +func Factory(conf *DatabaseConfig) (DatabaseType, error) { + switch conf.DatabaseType { case postgreSQLTypeName: + var details *sqlConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &sqlConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 23, + usernameLen: 63, + } + return &PostgreSQL{ - config: conf, + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + }, nil + + case cassandraTypeName: + var details *cassandraConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &cassandraConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &cassandraCredentialsProducer{} + + return &Cassandra{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, }, nil } @@ -29,21 +66,19 @@ func Factory(conf ConnectionConfig) (DatabaseType, error) { type DatabaseType interface { Type() string - Connection() (*sql.DB, error) - Close() - Reset(ConnectionConfig) (*sql.DB, error) - CreateUser(createStmt, username, password, expiration string) error + CreateUser(createStmt, rollbackStmt, username, password, expiration string) error RenewUser(username, expiration string) error - CustomRevokeUser(username, revocationSQL string) error - DefaultRevokeUser(username string) error + RevokeUser(username, revocationStmt string) error + + ConnectionProducer + CredentialsProducer } -type ConnectionConfig struct { - ConnectionType string `json:"type" structs:"type" mapstructure:"type"` - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` +type DatabaseConfig struct { + DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` } // Query templates a query for us. diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index ea7d08f8a..e050e30bf 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -11,9 +11,10 @@ import ( ) type PostgreSQL struct { - db *sql.DB - config ConnectionConfig + db *sql.DB + ConnectionProducer + CredentialsProducer sync.RWMutex } @@ -21,74 +22,18 @@ func (p *PostgreSQL) Type() string { return postgreSQLTypeName } -func (p *PostgreSQL) Connection() (*sql.DB, error) { - // Grab the write lock - p.Lock() - defer p.Unlock() - - // If we already have a DB, we got it! - if p.db != nil { - if err := p.db.Ping(); err == nil { - return p.db, nil - } - // If the ping was unsuccessful, close it and ignore errors as we'll be - // reestablishing anyways - p.db.Close() - } - - // Otherwise, attempt to make connection - conn := p.config.ConnectionURL - - // Ensure timezone is set to UTC for all the conenctions - if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { - if strings.Contains(conn, "?") { - conn += "&timezone=utc" - } else { - conn += "?timezone=utc" - } - } else { - conn += " timezone=utc" - } - - var err error - p.db, err = sql.Open("postgres", conn) +func (p *PostgreSQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() if err != nil { return nil, err } - // Set some connection pool settings. We don't need much of this, - // since the request rate shouldn't be high. - p.db.SetMaxOpenConns(p.config.MaxOpenConnections) - p.db.SetMaxIdleConns(p.config.MaxIdleConnections) - - return p.db, nil + return db.(*sql.DB), nil } -func (p *PostgreSQL) Close() { - // Grab the write lock - p.Lock() - defer p.Unlock() - - if p.db != nil { - p.db.Close() - } - - p.db = nil -} - -func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) { - // Grab the write lock - p.Lock() - p.config = config - p.Unlock() - - p.Close() - return p.Connection() -} - -func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { // Get the connection - db, err := p.Connection() + db, err := p.getConnection() if err != nil { return err } @@ -144,7 +89,7 @@ func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration strin } func (p *PostgreSQL) RenewUser(username, expiration string) error { - db, err := p.Connection() + db, err := p.getConnection() if err != nil { return err } @@ -170,14 +115,23 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { - db, err := p.Connection() +func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + if revocationStmt == "" { + return p.defaultRevokeUser(username) + } + + return p.customRevokeUser(username, revocationStmt) +} + +func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { + db, err := p.getConnection() if err != nil { return err } - // TODO: this is Racey - p.RLock() - defer p.RUnlock() tx, err := db.Begin() if err != nil { @@ -187,7 +141,7 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { tx.Rollback() }() - for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -213,12 +167,8 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { return nil } -func (p *PostgreSQL) DefaultRevokeUser(username string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() - - db, err := p.Connection() +func (p *PostgreSQL) defaultRevokeUser(username string) error { + db, err := p.getConnection() if err != nil { return err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index be017ea35..d4a969a69 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -79,7 +79,7 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo return nil, nil } - var config dbs.ConnectionConfig + var config dbs.DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -89,8 +89,8 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo } func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connURL := data.Get("connection_url").(string) connType := data.Get("connection_type").(string) + connDetails := data.Get("connection_details").(map[string]interface{}) maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -105,9 +105,9 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew maxIdleConns = maxOpenConns } - config := dbs.ConnectionConfig{ - ConnectionType: connType, - ConnectionURL: connURL, + config := &dbs.DatabaseConfig{ + DatabaseType: connType, + ConnectionDetails: connDetails, MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, } diff --git a/builtin/logical/database/path_config_lease.go b/builtin/logical/database/path_config_lease.go deleted file mode 100644 index 5cc40a056..000000000 --- a/builtin/logical/database/path_config_lease.go +++ /dev/null @@ -1,103 +0,0 @@ -package database - -import ( - "fmt" - "time" - - "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/logical/framework" -) - -func pathConfigLease(b *databaseBackend) *framework.Path { - return &framework.Path{ - Pattern: "config/lease", - Fields: map[string]*framework.FieldSchema{ - "lease": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Default lease for roles.", - }, - - "lease_max": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Maximum time a credential is valid for.", - }, - }, - - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathLeaseRead, - logical.UpdateOperation: b.pathLeaseWrite, - }, - - HelpSynopsis: pathConfigLeaseHelpSyn, - HelpDescription: pathConfigLeaseHelpDesc, - } -} - -func (b *databaseBackend) pathLeaseWrite( - req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - leaseRaw := d.Get("lease").(string) - leaseMaxRaw := d.Get("lease_max").(string) - - lease, err := time.ParseDuration(leaseRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid lease: %s", err)), nil - } - leaseMax, err := time.ParseDuration(leaseMaxRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid lease: %s", err)), nil - } - - // Store it - entry, err := logical.StorageEntryJSON("config/lease", &configLease{ - Lease: lease, - LeaseMax: leaseMax, - }) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err - } - - return nil, nil -} - -func (b *databaseBackend) pathLeaseRead( - req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - lease, err := b.Lease(req.Storage) - - if err != nil { - return nil, err - } - if lease == nil { - return nil, nil - } - - return &logical.Response{ - Data: map[string]interface{}{ - "lease": lease.Lease.String(), - "lease_max": lease.LeaseMax.String(), - }, - }, nil -} - -type configLease struct { - Lease time.Duration - LeaseMax time.Duration -} - -const pathConfigLeaseHelpSyn = ` -Configure the default lease information for generated credentials. -` - -const pathConfigLeaseHelpDesc = ` -This configures the default lease information used for credentials -generated by this backend. The lease specifies the duration that a -credential will be valid for, as well as the maximum session for -a set of credentials. - -The format for the lease is "1h" or integer and then unit. The longest -unit is hour. -` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 2a2386d01..15ca915ba 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -2,9 +2,7 @@ package database import ( "fmt" - "time" - "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" _ "github.com/lib/pq" @@ -45,41 +43,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil } - // Determine if we have a lease - b.logger.Trace("postgres/pathRoleCreateRead: getting lease") - lease, err := b.Lease(req.Storage) - if err != nil { - return nil, err - } - // Unlike some other backends we need a lease here (can't leave as 0 and - // let core fill it in) because Postgres also expires users as a safety - // measure, so cannot be zero - if lease == nil { - lease = &configLease{ - Lease: b.System().DefaultLeaseTTL(), - } - } - // Generate the username, password and expiration. PG limits user to 63 characters - displayName := req.DisplayName - if len(displayName) > 26 { - displayName = displayName[:26] - } - userUUID, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - username := fmt.Sprintf("%s-%s", displayName, userUUID) - if len(username) > 63 { - username = username[:63] - } - password, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - expiration := time.Now(). - Add(lease.Lease). - Format("2006-01-02 15:04:05-0700") // Get our handle b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") @@ -92,7 +56,19 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) } - err = db.CreateUser(role.CreationStatement, username, password, expiration) + username, err := db.GenerateUsername(req.DisplayName) + if err != nil { + return nil, err + } + + password, err := db.GeneratePassword() + if err != nil { + return nil, err + } + + expiration := db.GenerateExpiration(role.DefaultTTL) + + err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration) if err != nil { return nil, err } @@ -105,7 +81,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo "username": username, "role": name, }) - resp.Secret.TTL = lease.Lease + resp.Secret.TTL = role.DefaultTTL return resp, nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index e06518b28..dc8c6805a 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -1,6 +1,9 @@ package database import ( + "fmt" + "time" + "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -44,6 +47,24 @@ func pathRoles(b *databaseBackend) *framework.Path { array, or a base64-encoded serialized JSON string array. The '{{name}}' value will be substituted.`, }, + + "rollback_statement": { + Type: framework.TypeString, + Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + + "default_ttl": { + Type: framework.TypeString, + Description: "Default ttl for role.", + }, + + "max_ttl": { + Type: framework.TypeString, + Description: "Maximum time a credential is valid for", + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -79,6 +100,9 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie Data: map[string]interface{}{ "creation_statment": role.CreationStatement, "revocation_statement": role.RevocationStatement, + "rollback_statement": role.RollbackStatement, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), }, }, nil } @@ -97,6 +121,20 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F dbName := data.Get("db_name").(string) creationStmt := data.Get("creation_statement").(string) revocationStmt := data.Get("revocation_statement").(string) + rollbackStmt := data.Get("rollback_statement").(string) + defaultTTLRaw := data.Get("default_ttl").(string) + maxTTLRaw := data.Get("max_ttl").(string) + + defaultTTL, err := time.ParseDuration(defaultTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid default_ttl: %s", err)), nil + } + maxTTL, err := time.ParseDuration(maxTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_ttl: %s", err)), nil + } // TODO: Think about preparing the statments to test. @@ -105,6 +143,9 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F DBName: dbName, CreationStatement: creationStmt, RevocationStatement: revocationStmt, + RollbackStatement: rollbackStmt, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -117,9 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 30c4a6430..120804e91 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -1,7 +1,6 @@ package database import ( - "errors" "fmt" "github.com/hashicorp/vault/logical" @@ -31,8 +30,6 @@ func secretCreds(b *databaseBackend) *framework.Secret { } func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - dbName := d.Get("name").(string) - // Get the username from the internal data usernameRaw, ok := req.Secret.InternalData["username"] if !ok { @@ -40,27 +37,35 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi } username, ok := usernameRaw.(string) - // Get our connection - db, ok := b.connections[dbName] + roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName)) + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) } - // Get the lease information - lease, err := b.Lease(req.Storage) + role, err := b.Role(req.Storage, roleNameRaw.(string)) if err != nil { return nil, err } - if lease == nil { - lease = &configLease{} + if role == nil { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System()) + f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) resp, err := f(req, d) if err != nil { return nil, err } + // Grab the read lock + b.RLock() + defer b.RUnlock() + + // Get our connection + db, ok := b.connections[role.DBName] + if !ok { + return nil, fmt.Errorf("Could not find connection with name %s", role.DBName) + } + // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { expiration := expireTime.Format("2006-01-02 15:04:05-0700") @@ -124,23 +129,9 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) } - // TODO: Maybe move this down into db package? - switch revocationSQL { - - // This is the default revocation logic. If revocation SQL is provided it - // is simply executed as-is. - case "": - err := db.DefaultRevokeUser(username) - if err != nil { - return nil, err - } - - // We have revocation SQL, execute directly, within a transaction - default: - err := db.CustomRevokeUser(username, revocationSQL) - if err != nil { - return nil, err - } + err = db.RevokeUser(username, revocationSQL) + if err != nil { + return nil, err } return resp, nil From 46aa7142c1bb61fa413ba94d309d3eb240d86780 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 4 Jan 2017 10:18:10 -0800 Subject: [PATCH 003/162] Add mysql database type --- .../database/dbs/connectionproducer.go | 1 + .../database/dbs/credentialsproducer.go | 14 +- builtin/logical/database/dbs/mysql.go | 136 ++++++++++++++++++ .../database/path_config_connection.go | 2 +- 4 files changed, 145 insertions(+), 8 deletions(-) create mode 100644 builtin/logical/database/dbs/mysql.go diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index adecfd55a..dc8f6c82c 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -8,6 +8,7 @@ import ( "sync" "time" + _ "github.com/go-sql-driver/mysql" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go index 20210c2e8..94fce6275 100644 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -20,24 +20,24 @@ type sqlCredentialsProducer struct { usernameLen int } -func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { +func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { // Generate the username, password and expiration. PG limits user to 63 characters - if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen { - displayName = displayName[:scg.displayNameLen] + if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { + displayName = displayName[:scp.displayNameLen] } userUUID, err := uuid.GenerateUUID() if err != nil { return "", err } username := fmt.Sprintf("%s-%s", displayName, userUUID) - if scg.usernameLen > 0 && len(username) > scg.usernameLen { - username = username[:scg.usernameLen] + if scp.usernameLen > 0 && len(username) > scp.usernameLen { + username = username[:scp.usernameLen] } return username, nil } -func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { +func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { password, err := uuid.GenerateUUID() if err != nil { return "", err @@ -46,7 +46,7 @@ func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { return password, nil } -func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { +func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { return time.Now(). Add(ttl). Format("2006-01-02 15:04:05-0700") diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go new file mode 100644 index 000000000..314d4c929 --- /dev/null +++ b/builtin/logical/database/dbs/mysql.go @@ -0,0 +1,136 @@ +package dbs + +import ( + "database/sql" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/strutil" +) + +const defaultRevocationSQL = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' +` + +type MySQL struct { + db *sql.DB + + ConnectionProducer + CredentialsProducer + sync.RWMutex +} + +func (p *MySQL) Type() string { + return postgreSQLTypeName +} + +func (p *MySQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { + // Get the connection + db, err := p.getConnection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// NOOP +func (p *MySQL) RenewUser(username, expiration string) error { + return nil +} + +func (p *MySQL) RevokeUser(username, revocationStmt string) error { + // Get the connection + db, err := p.getConnection() + if err != nil { + return err + } + + // Grab the read lock + p.RLock() + defer p.RUnlock() + + // Use a default SQL statement for revocation if one cannot be fetched from the role + + if revocationStmt == "" { + revocationStmt = defaultRevocationSQL + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // This is not a prepared statement because not all commands are supported + // 1295: This command is not supported in the prepared statement protocol yet + // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ + query = strings.Replace(query, "{{name}}", username, -1) + _, err = tx.Exec(query) + if err != nil { + return err + } + + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index d4a969a69..90dfea4cd 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -124,7 +124,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew // Don't allow the connection type to change if b.connections[name].Type() != connType { - return logical.ErrorResponse("can not change type of existing connection"), nil + return logical.ErrorResponse("Can not change type of existing connection."), nil } db = b.connections[name] From 1f009518cdda280cb41d4e230a592d98fbbbac32 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 4 Jan 2017 10:53:39 -0800 Subject: [PATCH 004/162] s/Statement/Statements/ --- builtin/logical/database/dbs/cassandra.go | 13 ++++-- .../database/dbs/connectionproducer.go | 3 ++ builtin/logical/database/dbs/mysql.go | 14 +++--- builtin/logical/database/dbs/postgresql.go | 14 +++--- builtin/logical/database/path_role_create.go | 3 +- builtin/logical/database/path_roles.go | 46 +++++++++---------- 6 files changed, 50 insertions(+), 43 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index a8889032f..7a06e1314 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -9,6 +9,11 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) +const ( + defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultRollbackCQL = `DROP USER '{{username}}';` +) + type Cassandra struct { // Session is goroutine safe, however, since we reinitialize // it when connection info changes, we want to make sure we @@ -31,7 +36,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } -func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection session, err := c.getConnection() if err != nil { @@ -54,7 +59,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp }*/ // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -65,7 +70,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp "password": password, })).Exec() if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -88,7 +93,7 @@ func (c *Cassandra) RenewUser(username, expiration string) error { return nil } -func (c *Cassandra) RevokeUser(username, revocationSQL string) error { +func (c *Cassandra) RevokeUser(username, revocationStmts string) error { session, err := c.getConnection() if err != nil { return err diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index dc8f6c82c..5c606996d 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -8,7 +8,10 @@ import ( "sync" "time" + // Import sql drivers _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 314d4c929..0a18683ea 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) -const defaultRevocationSQL = ` +const defaultRevocationStmts = ` REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; DROP USER '{{name}}'@'%' ` @@ -34,7 +34,7 @@ func (p *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -54,7 +54,7 @@ func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expirat defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -86,7 +86,7 @@ func (p *MySQL) RenewUser(username, expiration string) error { return nil } -func (p *MySQL) RevokeUser(username, revocationStmt string) error { +func (p *MySQL) RevokeUser(username, revocationStmts string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -99,8 +99,8 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error { // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmt == "" { - revocationStmt = defaultRevocationSQL + if revocationStmts == "" { + revocationStmts = defaultRevocationStmts } // Start a transaction @@ -110,7 +110,7 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error { } defer tx.Rollback() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index e050e30bf..01fb3cd70 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -31,7 +31,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -56,7 +56,7 @@ func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, ex // Return the secret // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -115,19 +115,19 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error { +func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error { // Grab the read lock p.RLock() defer p.RUnlock() - if revocationStmt == "" { + if revocationStmts == "" { return p.defaultRevokeUser(username) } - return p.customRevokeUser(username, revocationStmt) + return p.customRevokeUser(username, revocationStmts) } -func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { +func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { db, err := p.getConnection() if err != nil { return err @@ -141,7 +141,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { tx.Rollback() }() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 15ca915ba..b1cce97f3 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -5,7 +5,6 @@ import ( "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" - _ "github.com/lib/pq" ) func pathRoleCreate(b *databaseBackend) *framework.Path { @@ -68,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo expiration := db.GenerateExpiration(role.DefaultTTL) - err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration) + err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index dc8c6805a..994d084f0 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -35,12 +35,12 @@ func pathRoles(b *databaseBackend) *framework.Path { Description: "Name of the database this role acts on.", }, - "creation_statement": { + "creation_statements": { Type: framework.TypeString, Description: "SQL string to create a user. See help for more info.", }, - "revocation_statement": { + "revocation_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string @@ -48,7 +48,7 @@ func pathRoles(b *databaseBackend) *framework.Path { will be substituted.`, }, - "rollback_statement": { + "rollback_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string @@ -98,11 +98,11 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statment": role.CreationStatement, - "revocation_statement": role.RevocationStatement, - "rollback_statement": role.RollbackStatement, - "default_ttl": role.DefaultTTL.String(), - "max_ttl": role.MaxTTL.String(), + "creation_statments": role.CreationStatements, + "revocation_statements": role.RevocationStatements, + "rollback_statements": role.RollbackStatements, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), }, }, nil } @@ -119,9 +119,9 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) dbName := data.Get("db_name").(string) - creationStmt := data.Get("creation_statement").(string) - revocationStmt := data.Get("revocation_statement").(string) - rollbackStmt := data.Get("rollback_statement").(string) + creationStmts := data.Get("creation_statements").(string) + revocationStmts := data.Get("revocation_statements").(string) + rollbackStmts := data.Get("rollback_statements").(string) defaultTTLRaw := data.Get("default_ttl").(string) maxTTLRaw := data.Get("max_ttl").(string) @@ -140,12 +140,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F // Store it entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - CreationStatement: creationStmt, - RevocationStatement: revocationStmt, - RollbackStatement: rollbackStmt, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, + DBName: dbName, + CreationStatements: creationStmts, + RevocationStatements: revocationStmts, + RollbackStatements: rollbackStmts, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -158,12 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` From 8e8f260d96f9522aaaed98ecaf75fd03c963cbc3 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 4 Jan 2017 11:28:30 -0800 Subject: [PATCH 005/162] Add max connection lifetime param and set consistancy on cassandra session --- .../database/dbs/connectionproducer.go | 13 +++++++- builtin/logical/database/dbs/db.go | 10 ++++--- .../database/path_config_connection.go | 30 ++++++++++++------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 5c606996d..e1a7ae9bb 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -78,6 +78,7 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) { // since the request rate shouldn't be high. cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) + cp.db.SetConnMaxLifetime(cp.config.MaxConnectionLifetime) return cp.db, nil } @@ -127,7 +128,7 @@ type cassandraConnectionDetails struct { ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` } type cassandraConnectionProducer struct { @@ -248,6 +249,16 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet return nil, fmt.Errorf("Error creating session: %s", err) } + // Set consistency + if cfg.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(cfg.Consistency) + if err != nil { + return nil, err + } + + session.SetConsistency(consistencyValue) + } + // Verify the info err = session.Query(`LIST USERS`).Exec() if err != nil { diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 9d261ff42..e901f69f8 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/mitchellh/mapstructure" ) @@ -75,10 +76,11 @@ type DatabaseType interface { } type DatabaseConfig struct { - DatabaseType string `json:"type" structs:"type" mapstructure:"type"` - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` } // Query templates a query for us. diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 90dfea4cd..06cf1dd4c 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -2,12 +2,12 @@ package database import ( "fmt" + "time" "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" - _ "github.com/lib/pq" ) func pathConfigConnection(b *databaseBackend) *framework.Path { @@ -24,11 +24,6 @@ func pathConfigConnection(b *databaseBackend) *framework.Path { Description: "DB type (e.g. postgres)", }, - "connection_url": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "DB connection string", - }, - "connection_details": &framework.FieldSchema{ Type: framework.TypeMap, Description: "Connection details for specified connection type.", @@ -55,6 +50,12 @@ and a negative value disables idle connections. If larger than max_open_connections it will be reduced to the same size.`, }, + + "max_connection_lifetime": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum amount of time a connection may be reused; + a zero or negative value reuses connections forever.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -105,11 +106,19 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew maxIdleConns = maxOpenConns } + maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) + maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_connection_lifetime: %s", err)), nil + } + config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: connDetails, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, + DatabaseType: connType, + ConnectionDetails: connDetails, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + MaxConnectionLifetime: maxConnLifetime, } name := data.Get("name").(string) @@ -118,7 +127,6 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew b.Lock() defer b.Unlock() - var err error var db dbs.DatabaseType if _, ok := b.connections[name]; ok { From 24ddea995487c632bc1f5400fd6046041eca2760 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Feb 2017 17:32:08 -0800 Subject: [PATCH 006/162] Add mysql into the factory --- builtin/logical/database/dbs/db.go | 23 +++++++++++++++++++ builtin/logical/database/dbs/mysql.go | 2 +- .../database/path_config_connection.go | 6 ++--- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index e901f69f8..d648b776f 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -11,6 +11,7 @@ import ( const ( postgreSQLTypeName = "postgres" + mySQLTypeName = "mysql" cassandraTypeName = "cassandra" ) @@ -42,6 +43,28 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { CredentialsProducer: credsProducer, }, nil + case mySQLTypeName: + var details *sqlConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &sqlConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 4, + usernameLen: 16, + } + + return &MySQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + }, nil + case cassandraTypeName: var details *cassandraConnectionDetails err := mapstructure.Decode(conf.ConnectionDetails, &details) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0a18683ea..ce6cdac92 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -22,7 +22,7 @@ type MySQL struct { } func (p *MySQL) Type() string { - return postgreSQLTypeName + return mySQLTypeName } func (p *MySQL) getConnection() (*sql.DB, error) { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 06cf1dd4c..c3f72b743 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -52,7 +52,8 @@ reduced to the same size.`, }, "max_connection_lifetime": &framework.FieldSchema{ - Type: framework.TypeInt, + Type: framework.TypeString, + Default: "0s", Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, }, @@ -91,7 +92,6 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connType := data.Get("connection_type").(string) - connDetails := data.Get("connection_details").(map[string]interface{}) maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -115,7 +115,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew config := &dbs.DatabaseConfig{ DatabaseType: connType, - ConnectionDetails: connDetails, + ConnectionDetails: data.Raw, MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, From 29e07ac9e81fdf34890fc0889f099eff61f78ef8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 15 Feb 2017 14:31:15 -0800 Subject: [PATCH 007/162] Fix mysql connections --- builtin/logical/database/dbs/connectionproducer.go | 2 -- builtin/logical/database/path_config_connection.go | 5 ----- 2 files changed, 7 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index e1a7ae9bb..b53bb0c75 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -64,8 +64,6 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) { } else { conn += "?timezone=utc" } - } else { - conn += " timezone=utc" } var err error diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index c3f72b743..9fe926050 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -24,11 +24,6 @@ func pathConfigConnection(b *databaseBackend) *framework.Path { Description: "DB type (e.g. postgres)", }, - "connection_details": &framework.FieldSchema{ - Type: framework.TypeMap, - Description: "Connection details for specified connection type.", - }, - "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, From bba832e6bfe53abdd0cd483749b16f0d57db760c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 15 Feb 2017 16:51:59 -0800 Subject: [PATCH 008/162] Make db instances immutable and add a reset path to tear down and create a new database instance with an updated config --- builtin/logical/database/backend.go | 1 + .../database/dbs/connectionproducer.go | 152 +++++++----------- builtin/logical/database/dbs/db.go | 30 ++-- .../database/path_config_connection.go | 66 +++++++- 4 files changed, 125 insertions(+), 124 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 3d757df1d..fe853d3fb 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -25,6 +25,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), + pathResetConnection(&b), }, Secrets: []*framework.Secret{ diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index b53bb0c75..1e66d27f6 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -15,47 +15,40 @@ import ( "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" - "github.com/mitchellh/mapstructure" ) type ConnectionProducer interface { Connection() (interface{}, error) Close() - // TODO: Should we make this immutable instead? - Reset(*DatabaseConfig) error } // sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases -type sqlConnectionDetails struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` -} - type sqlConnectionProducer struct { + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + config *DatabaseConfig - // TODO: Should we merge these two structures make it immutable? - connDetails *sqlConnectionDetails db *sql.DB sync.Mutex } -func (cp *sqlConnectionProducer) Connection() (interface{}, error) { +func (c *sqlConnectionProducer) Connection() (interface{}, error) { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() // If we already have a DB, we got it! - if cp.db != nil { - if err := cp.db.Ping(); err == nil { - return cp.db, nil + if c.db != nil { + if err := c.db.Ping(); err == nil { + return c.db, nil } // If the ping was unsuccessful, close it and ignore errors as we'll be // reestablishing anyways - cp.db.Close() + c.db.Close() } // Otherwise, attempt to make connection - conn := cp.connDetails.ConnectionURL + conn := c.ConnectionURL // Ensure timezone is set to UTC for all the conenctions if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { @@ -67,54 +60,33 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) { } var err error - cp.db, err = sql.Open(cp.config.DatabaseType, conn) + c.db, err = sql.Open(c.config.DatabaseType, conn) if err != nil { return nil, err } // Set some connection pool settings. We don't need much of this, // since the request rate shouldn't be high. - cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) - cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) - cp.db.SetConnMaxLifetime(cp.config.MaxConnectionLifetime) + c.db.SetMaxOpenConns(c.config.MaxOpenConnections) + c.db.SetMaxIdleConns(c.config.MaxIdleConnections) + c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime) - return cp.db, nil + return c.db, nil } -func (cp *sqlConnectionProducer) Close() { +func (c *sqlConnectionProducer) Close() { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() - if cp.db != nil { - cp.db.Close() + if c.db != nil { + c.db.Close() } - cp.db = nil + c.db = nil } -func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error { - // Grab the write lock - cp.Lock() - - var details *sqlConnectionDetails - err := mapstructure.Decode(config.ConnectionDetails, &details) - if err != nil { - return err - } - - cp.connDetails = details - cp.config = config - - cp.Unlock() - - cp.Close() - _, err = cp.Connection() - return err -} - -// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra -type cassandraConnectionDetails struct { +type cassandraConnectionProducer struct { Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` Username string `json:"username" structs:"username" mapstructure:"username"` Password string `json:"password" structs:"password" mapstructure:"password"` @@ -127,90 +99,74 @@ type cassandraConnectionDetails struct { ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` -} -type cassandraConnectionProducer struct { config *DatabaseConfig - // TODO: Should we merge these two structures make it immutable? - connDetails *cassandraConnectionDetails session *gocql.Session sync.Mutex } -func (cp *cassandraConnectionProducer) Connection() (interface{}, error) { +func (c *cassandraConnectionProducer) Connection() (interface{}, error) { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() // If we already have a DB, we got it! - if cp.session != nil { - return cp.session, nil + if c.session != nil { + return c.session, nil } - session, err := cp.createSession(cp.connDetails) + session, err := c.createSession() if err != nil { return nil, err } // Store the session in backend for reuse - cp.session = session + c.session = session return session, nil } -func (cp *cassandraConnectionProducer) Close() { +func (c *cassandraConnectionProducer) Close() { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() - if cp.session != nil { - cp.session.Close() + if c.session != nil { + c.session.Close() } - cp.session = nil + c.session = nil } -func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error { - // Grab the write lock - cp.Lock() - cp.config = config - cp.Unlock() - - cp.Close() - _, err := cp.Connection() - - return err -} - -func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) +func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: cfg.Username, - Password: cfg.Password, + Username: c.Username, + Password: c.Password, } - clusterConfig.ProtoVersion = cfg.ProtocolVersion + clusterConfig.ProtoVersion = c.ProtocolVersion if clusterConfig.ProtoVersion == 0 { clusterConfig.ProtoVersion = 2 } - clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - if cfg.TLS { + if c.TLS { var tlsConfig *tls.Config - if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { - if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { + if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") } certBundle := &certutil.CertBundle{} - if len(cfg.Certificate) > 0 { - certBundle.Certificate = cfg.Certificate - certBundle.PrivateKey = cfg.PrivateKey + if len(c.Certificate) > 0 { + certBundle.Certificate = c.Certificate + certBundle.PrivateKey = c.PrivateKey } - if len(cfg.IssuingCA) > 0 { - certBundle.IssuingCA = cfg.IssuingCA + if len(c.IssuingCA) > 0 { + certBundle.IssuingCA = c.IssuingCA } parsedCertBundle, err := certBundle.ToParsedCertBundle() @@ -222,11 +178,11 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet if err != nil || tlsConfig == nil { return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) } - tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + tlsConfig.InsecureSkipVerify = c.InsecureTLS - if cfg.TLSMinVersion != "" { + if c.TLSMinVersion != "" { var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] if !ok { return nil, fmt.Errorf("invalid 'tls_min_version' in config") } @@ -248,8 +204,8 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet } // Set consistency - if cfg.Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(cfg.Consistency) + if c.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index d648b776f..4c04c0fd4 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -22,16 +22,12 @@ var ( func Factory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: - var details *sqlConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *sqlConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &sqlConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &sqlCredentialsProducer{ displayNameLen: 23, @@ -44,16 +40,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case mySQLTypeName: - var details *sqlConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *sqlConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &sqlConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &sqlCredentialsProducer{ displayNameLen: 4, @@ -66,16 +58,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case cassandraTypeName: - var details *cassandraConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *cassandraConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &cassandraConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &cassandraCredentialsProducer{} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 9fe926050..085113fe9 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,6 +1,7 @@ package database import ( + "errors" "fmt" "time" @@ -10,6 +11,64 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +func pathResetConnection(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: fmt.Sprintf("reset/%s", framework.GenericNameRegex("name")), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of this DB type", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathConnectionReset, + }, + + HelpSynopsis: pathConfigConnectionHelpSyn, + HelpDescription: pathConfigConnectionHelpDesc, + } +} + +func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return nil, errors.New("No database name set") + } + + // Grab the mutex lock + b.Lock() + defer b.Unlock() + + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + db, ok := b.connections[name] + if !ok { + return logical.ErrorResponse("Can not change type of existing connection."), nil + } + + db.Close() + db, err = dbs.Factory(&config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + b.connections[name] = db + + return nil, nil +} + func pathConfigConnection(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), @@ -129,13 +188,13 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew if b.connections[name].Type() != connType { return logical.ErrorResponse("Can not change type of existing connection."), nil } - - db = b.connections[name] } else { db, err = dbs.Factory(config) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + + b.connections[name] = db } /* @@ -166,9 +225,6 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew } // Reset the DB connection - db.Reset(config) - b.connections[name] = db - resp := &logical.Response{} resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") From bc53e119ca1fc50e5a7896a8a77542b9fa02aace Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 3 Mar 2017 15:07:41 -0800 Subject: [PATCH 009/162] rename mysql variable --- builtin/logical/database/dbs/mysql.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index ce6cdac92..30452ca54 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -21,12 +21,12 @@ type MySQL struct { sync.RWMutex } -func (p *MySQL) Type() string { +func (m *MySQL) Type() string { return mySQLTypeName } -func (p *MySQL) getConnection() (*sql.DB, error) { - db, err := p.Connection() +func (m *MySQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() if err != nil { return nil, err } @@ -34,17 +34,17 @@ func (p *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection - db, err := p.getConnection() + db, err := m.getConnection() if err != nil { return err } // TODO: This is racey // Grab a read lock - p.RLock() - defer p.RUnlock() + m.RLock() + defer m.RUnlock() // Start a transaction tx, err := db.Begin() @@ -82,20 +82,20 @@ func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir } // NOOP -func (p *MySQL) RenewUser(username, expiration string) error { +func (m *MySQL) RenewUser(username, expiration string) error { return nil } -func (p *MySQL) RevokeUser(username, revocationStmts string) error { +func (m *MySQL) RevokeUser(username, revocationStmts string) error { // Get the connection - db, err := p.getConnection() + db, err := m.getConnection() if err != nil { return err } // Grab the read lock - p.RLock() - defer p.RUnlock() + m.RLock() + defer m.RUnlock() // Use a default SQL statement for revocation if one cannot be fetched from the role From c959882b938b59ae6d489f0f3a522fe3e65ee824 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 13:48:29 -0800 Subject: [PATCH 010/162] Update locking functionaility --- builtin/logical/database/dbs/cassandra.go | 30 ++++++------------- .../database/dbs/connectionproducer.go | 14 ++++----- builtin/logical/database/dbs/mysql.go | 22 ++++++-------- builtin/logical/database/dbs/postgresql.go | 27 ++++++++--------- 4 files changed, 36 insertions(+), 57 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 7a06e1314..9c5607e0d 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -3,7 +3,6 @@ package dbs import ( "fmt" "strings" - "sync" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/strutil" @@ -20,7 +19,6 @@ type Cassandra struct { // can close it and use a new connection; hence the lock ConnectionProducer CredentialsProducer - sync.RWMutex } func (c *Cassandra) Type() string { @@ -28,7 +26,7 @@ func (c *Cassandra) Type() string { } func (c *Cassandra) getConnection() (*gocql.Session, error) { - session, err := c.Connection() + session, err := c.connection() if err != nil { return nil, err } @@ -37,27 +35,16 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { } func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { + // Grab the lock + c.Lock() + defer c.Unlock() + // Get the connection session, err := c.getConnection() if err != nil { return err } - // TODO: This is racey - // Grab a read lock - c.RLock() - defer c.RUnlock() - - // Set consistency - /* if .Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency) - if err != nil { - return err - } - - session.SetConsistency(consistencyValue) - }*/ - // Execute each query for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) @@ -94,13 +81,14 @@ func (c *Cassandra) RenewUser(username, expiration string) error { } func (c *Cassandra) RevokeUser(username, revocationStmts string) error { + // Grab the lock + c.Lock() + defer c.Unlock() + session, err := c.getConnection() if err != nil { return err } - // TODO: this is Racey - c.RLock() - defer c.RUnlock() err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() if err != nil { diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 1e66d27f6..268ab615c 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -18,8 +18,10 @@ import ( ) type ConnectionProducer interface { - Connection() (interface{}, error) Close() + + sync.Locker + connection() (interface{}, error) } // sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases @@ -32,12 +34,8 @@ type sqlConnectionProducer struct { sync.Mutex } -func (c *sqlConnectionProducer) Connection() (interface{}, error) { - // Grab the write lock - c.Lock() - defer c.Unlock() - - // If we already have a DB, we got it! +func (c *sqlConnectionProducer) connection() (interface{}, error) { + // If we already have a DB, test it and return if c.db != nil { if err := c.db.Ping(); err == nil { return c.db, nil @@ -106,7 +104,7 @@ type cassandraConnectionProducer struct { sync.Mutex } -func (c *cassandraConnectionProducer) Connection() (interface{}, error) { +func (c *cassandraConnectionProducer) connection() (interface{}, error) { // Grab the write lock c.Lock() defer c.Unlock() diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 30452ca54..b5574d1a5 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -3,7 +3,6 @@ package dbs import ( "database/sql" "strings" - "sync" "github.com/hashicorp/vault/helper/strutil" ) @@ -18,7 +17,6 @@ type MySQL struct { ConnectionProducer CredentialsProducer - sync.RWMutex } func (m *MySQL) Type() string { @@ -26,7 +24,7 @@ func (m *MySQL) Type() string { } func (m *MySQL) getConnection() (*sql.DB, error) { - db, err := m.Connection() + db, err := m.connection() if err != nil { return nil, err } @@ -35,17 +33,16 @@ func (m *MySQL) getConnection() (*sql.DB, error) { } func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { + // Grab the lock + m.Lock() + defer m.Unlock() + // Get the connection db, err := m.getConnection() if err != nil { return err } - // TODO: This is racey - // Grab a read lock - m.RLock() - defer m.RUnlock() - // Start a transaction tx, err := db.Begin() if err != nil { @@ -87,18 +84,17 @@ func (m *MySQL) RenewUser(username, expiration string) error { } func (m *MySQL) RevokeUser(username, revocationStmts string) error { + // Grab the read lock + m.Lock() + defer m.Unlock() + // Get the connection db, err := m.getConnection() if err != nil { return err } - // Grab the read lock - m.RLock() - defer m.RUnlock() - // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmts == "" { revocationStmts = defaultRevocationStmts } diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 01fb3cd70..32c049721 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" "strings" - "sync" "github.com/hashicorp/vault/helper/strutil" "github.com/lib/pq" @@ -15,7 +14,6 @@ type PostgreSQL struct { ConnectionProducer CredentialsProducer - sync.RWMutex } func (p *PostgreSQL) Type() string { @@ -23,7 +21,7 @@ func (p *PostgreSQL) Type() string { } func (p *PostgreSQL) getConnection() (*sql.DB, error) { - db, err := p.Connection() + db, err := p.connection() if err != nil { return nil, err } @@ -32,17 +30,16 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { } func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { + // Grab the lock + p.Lock() + defer p.Unlock() + // Get the connection db, err := p.getConnection() if err != nil { return err } - // TODO: This is racey - // Grab a read lock - p.RLock() - defer p.RUnlock() - // Start a transaction // b.logger.Trace("postgres/pathRoleCreateRead: starting transaction") tx, err := db.Begin() @@ -89,14 +86,14 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, } func (p *PostgreSQL) RenewUser(username, expiration string) error { + // Grab the lock + p.Lock() + defer p.Unlock() + db, err := p.getConnection() if err != nil { return err } - // TODO: This is Racey - // Grab the read lock - p.RLock() - defer p.RUnlock() query := fmt.Sprintf( "ALTER ROLE %s VALID UNTIL '%s';", @@ -116,9 +113,9 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { } func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() + // Grab the lock + p.Lock() + defer p.Unlock() if revocationStmts == "" { return p.defaultRevokeUser(username) From 919155ab12855a483416ca3ac4f3d42aea3c3de7 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 15:33:05 -0800 Subject: [PATCH 011/162] Remove double lock --- builtin/logical/database/dbs/connectionproducer.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 268ab615c..82da37cc7 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -105,11 +105,7 @@ type cassandraConnectionProducer struct { } func (c *cassandraConnectionProducer) connection() (interface{}, error) { - // Grab the write lock - c.Lock() - defer c.Unlock() - - // If we already have a DB, we got it! + // If we already have a DB, return it if c.session != nil { return c.session, nil } From 843d5842545cc592199ce26f94aa47a6172b00f4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 15:34:23 -0800 Subject: [PATCH 012/162] Remove unused sql object --- builtin/logical/database/dbs/mysql.go | 2 -- builtin/logical/database/dbs/postgresql.go | 2 -- 2 files changed, 4 deletions(-) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index b5574d1a5..0cf77062c 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -13,8 +13,6 @@ const defaultRevocationStmts = ` ` type MySQL struct { - db *sql.DB - ConnectionProducer CredentialsProducer } diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 32c049721..468746fc4 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -10,8 +10,6 @@ import ( ) type PostgreSQL struct { - db *sql.DB - ConnectionProducer CredentialsProducer } From 3976a2a0a6750f46582dfe1e8502e3118c234854 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 16:48:17 -0800 Subject: [PATCH 013/162] Pass statements object --- builtin/logical/database/dbs/cassandra.go | 10 ++-- builtin/logical/database/dbs/db.go | 13 +++-- builtin/logical/database/dbs/mysql.go | 13 ++--- builtin/logical/database/dbs/postgresql.go | 12 ++--- builtin/logical/database/path_role_create.go | 2 +- builtin/logical/database/path_roles.go | 50 +++++++++++++------- 6 files changed, 62 insertions(+), 38 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 9c5607e0d..9956372d6 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -34,7 +34,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } -func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error { // Grab the lock c.Lock() defer c.Unlock() @@ -46,7 +46,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e } // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -57,7 +57,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e "password": password, })).Exec() if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.RollbackStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -75,12 +75,12 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e return nil } -func (c *Cassandra) RenewUser(username, expiration string) error { +func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error { // NOOP return nil } -func (c *Cassandra) RevokeUser(username, revocationStmts string) error { +func (c *Cassandra) RevokeUser(statements Statements, username string) error { // Grab the lock c.Lock() defer c.Unlock() diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 4c04c0fd4..e3e8cb39b 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -78,9 +78,9 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { type DatabaseType interface { Type() string - CreateUser(createStmt, rollbackStmt, username, password, expiration string) error - RenewUser(username, expiration string) error - RevokeUser(username, revocationStmt string) error + CreateUser(statements Statements, username, password, expiration string) error + RenewUser(statements Statements, username, expiration string) error + RevokeUser(statements Statements, username string) error ConnectionProducer CredentialsProducer @@ -94,6 +94,13 @@ type DatabaseConfig struct { MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` } +type Statements struct { + CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` + RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` + RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` + RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` +} + // Query templates a query for us. func queryHelper(tpl string, data map[string]string) string { for k, v := range data { diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0cf77062c..0ff015415 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) -const defaultRevocationStmts = ` +const defaultMysqlRevocationStmts = ` REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; DROP USER '{{name}}'@'%' ` @@ -30,7 +30,7 @@ func (m *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error { // Grab the lock m.Lock() defer m.Unlock() @@ -49,7 +49,7 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -77,11 +77,11 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir } // NOOP -func (m *MySQL) RenewUser(username, expiration string) error { +func (m *MySQL) RenewUser(statements Statements, username, expiration string) error { return nil } -func (m *MySQL) RevokeUser(username, revocationStmts string) error { +func (m *MySQL) RevokeUser(statements Statements, username string) error { // Grab the read lock m.Lock() defer m.Unlock() @@ -92,9 +92,10 @@ func (m *MySQL) RevokeUser(username, revocationStmts string) error { return err } + revocationStmts := statements.RevocationStatements // Use a default SQL statement for revocation if one cannot be fetched from the role if revocationStmts == "" { - revocationStmts = defaultRevocationStmts + revocationStmts = defaultMysqlRevocationStmts } // Start a transaction diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 468746fc4..51b72ebc8 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -27,7 +27,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error { // Grab the lock p.Lock() defer p.Unlock() @@ -51,7 +51,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, // Return the secret // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -83,7 +83,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, return nil } -func (p *PostgreSQL) RenewUser(username, expiration string) error { +func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error { // Grab the lock p.Lock() defer p.Unlock() @@ -110,16 +110,16 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error { +func (p *PostgreSQL) RevokeUser(statements Statements, username string) error { // Grab the lock p.Lock() defer p.Unlock() - if revocationStmts == "" { + if statements.RevocationStatements == "" { return p.defaultRevokeUser(username) } - return p.customRevokeUser(username, revocationStmts) + return p.customRevokeUser(username, statements.RevocationStatements) } func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index b1cce97f3..3f7a513c8 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -67,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo expiration := db.GenerateExpiration(role.DefaultTTL) - err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration) + err = db.CreateUser(role.Statements, username, password, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 994d084f0..1268b05a0 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -42,12 +43,18 @@ func pathRoles(b *databaseBackend) *framework.Path { "revocation_statements": { Type: framework.TypeString, - Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + Description: `Statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + "renew_statements": { + Type: framework.TypeString, + Description: `Statements to be executed to renew a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string array, or a base64-encoded serialized JSON string array. The '{{name}}' value will be substituted.`, }, - "rollback_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated @@ -98,9 +105,10 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statments": role.CreationStatements, - "revocation_statements": role.RevocationStatements, - "rollback_statements": role.RollbackStatements, + "creation_statments": role.Statements.CreationStatements, + "revocation_statements": role.Statements.RevocationStatements, + "rollback_statements": role.Statements.RollbackStatements, + "renew_statements": role.Statements.RenewStatements, "default_ttl": role.DefaultTTL.String(), "max_ttl": role.MaxTTL.String(), }, @@ -119,9 +127,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) dbName := data.Get("db_name").(string) + + // Get statements creationStmts := data.Get("creation_statements").(string) revocationStmts := data.Get("revocation_statements").(string) rollbackStmts := data.Get("rollback_statements").(string) + renewStmts := data.Get("renew_statements").(string) + + // Get TTLs defaultTTLRaw := data.Get("default_ttl").(string) maxTTLRaw := data.Get("max_ttl").(string) @@ -136,16 +149,21 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F "Invalid max_ttl: %s", err)), nil } + statements := dbs.Statements{ + CreationStatements: creationStmts, + RevocationStatements: revocationStmts, + RollbackStatements: rollbackStmts, + RenewStatements: rollbackStmts, + } + // TODO: Think about preparing the statments to test. // Store it entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - CreationStatements: creationStmts, - RevocationStatements: revocationStmts, - RollbackStatements: rollbackStmts, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, + DBName: dbName, + Statements: statements, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -158,12 +176,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` From b7c3b4b0d757fff268ee783ffdaf991ad54e0cf3 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 17:00:52 -0800 Subject: [PATCH 014/162] Add defaults to the cassandra databse type --- builtin/logical/database/dbs/cassandra.go | 13 +++++++++++-- builtin/logical/database/dbs/db.go | 2 ++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 9956372d6..1be26766b 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -45,8 +45,17 @@ func (c *Cassandra) CreateUser(statements Statements, username, password, expira return err } + creationCQL := statements.CreationStatements + if creationCQL == "" { + creationCQL = defaultCreationCQL + } + rollbackCQL := statements.RollbackStatements + if rollbackCQL == "" { + rollbackCQL = defaultRollbackCQL + } + // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -57,7 +66,7 @@ func (c *Cassandra) CreateUser(statements Statements, username, password, expira "password": password, })).Exec() if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(statements.RollbackStatements, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index e3e8cb39b..e173e2dd8 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -94,6 +94,8 @@ type DatabaseConfig struct { MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` } +// Statments set in role creation and passed into the database type's functions. +// TODO: Add a way of setting defaults here. type Statements struct { CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` From 2fb6bf988223a0e80e6f65160ced113ac8bf94d4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 17:21:44 -0800 Subject: [PATCH 015/162] Fix renew and revoke calls --- builtin/logical/database/path_roles.go | 2 +- builtin/logical/database/secret_creds.go | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 1268b05a0..9a5bb9324 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -153,7 +153,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F CreationStatements: creationStmts, RevocationStatements: revocationStmts, RollbackStatements: rollbackStmts, - RenewStatements: rollbackStmts, + RenewStatements: renewStmts, } // TODO: Think about preparing the statments to test. diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 120804e91..90b88082e 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -70,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { expiration := expireTime.Format("2006-01-02 15:04:05-0700") - err := db.RenewUser(username, expiration) + err := db.RenewUser(role.Statements, username, expiration) if err != nil { return nil, err } @@ -87,7 +87,6 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F } username, ok := usernameRaw.(string) - var revocationSQL string var resp *logical.Response roleNameRaw, ok := req.Secret.InternalData["role"] @@ -129,7 +128,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) } - err = db.RevokeUser(username, revocationSQL) + err = db.RevokeUser(role.Statements, username) if err != nil { return nil, err } From b7128f8370286c1b98b696a166c63addd24b4575 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 8 Mar 2017 14:46:53 -0800 Subject: [PATCH 016/162] Update secrets fields --- builtin/logical/database/dbs/mysql.go | 5 ++++ .../database/path_config_connection.go | 30 +++++++++---------- builtin/logical/database/secret_creds.go | 14 ++------- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0ff015415..0d8be2a47 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -2,6 +2,7 @@ package dbs import ( "database/sql" + "fmt" "strings" "github.com/hashicorp/vault/helper/strutil" @@ -41,6 +42,10 @@ func (m *MySQL) CreateUser(statements Statements, username, password, expiration return err } + if statements.CreationStatements == "" { + return fmt.Errorf("Empty creation statements") + } + // Start a transaction tx, err := db.Begin() if err != nil { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 085113fe9..c2fc085ae 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -197,22 +197,22 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew b.connections[name] = db } - /* - // Don't check the connection_url if verification is disabled - verifyConnection := data.Get("verify_connection").(bool) - if verifyConnection { - // Verify the string - db, err := sql.Open("postgres", connURL) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil - } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil - } + /* TODO: + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + } */ // Store it diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 90b88082e..e39525a18 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -11,18 +11,8 @@ const SecretCredsType = "creds" func secretCreds(b *databaseBackend) *framework.Secret { return &framework.Secret{ - Type: SecretCredsType, - Fields: map[string]*framework.FieldSchema{ - "username": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Username", - }, - - "password": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Password", - }, - }, + Type: SecretCredsType, + Fields: map[string]*framework.FieldSchema{}, Renew: b.secretCredsRenew, Revoke: b.secretCredsRevoke, From 9099231229af1e847194e3a3d3f2bae08cdd254d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 17:43:37 -0800 Subject: [PATCH 017/162] Add plugin features --- .../logical/database/dbs/credentialsproducer.go | 10 +++++----- builtin/logical/database/dbs/db.go | 16 +++++++++++++++- .../logical/database/path_config_connection.go | 10 ++++++++++ builtin/logical/database/path_role_create.go | 5 ++++- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go index 94fce6275..5ae3b128e 100644 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -11,7 +11,7 @@ import ( type CredentialsProducer interface { GenerateUsername(displayName string) (string, error) GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) string + GenerateExpiration(ttl time.Duration) (string, error) } // sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types. @@ -46,10 +46,10 @@ func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { return password, nil } -func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { +func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { return time.Now(). Add(ttl). - Format("2006-01-02 15:04:05-0700") + Format("2006-01-02 15:04:05-0700"), nil } type cassandraCredentialsProducer struct{} @@ -74,6 +74,6 @@ func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { return password, nil } -func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string { - return "" +func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { + return "", nil } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index e173e2dd8..063cc89cf 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -13,6 +13,7 @@ const ( postgreSQLTypeName = "postgres" mySQLTypeName = "mysql" cassandraTypeName = "cassandra" + pluginTypeName = "plugin" ) var ( @@ -71,6 +72,18 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { ConnectionProducer: connProducer, CredentialsProducer: credsProducer, }, nil + + case pluginTypeName: + if conf.PluginCommand == "" { + return nil, errors.New("ERROR") + } + + db, err := newPluginClient(conf.PluginCommand) + if err != nil { + return nil, err + } + + return db, nil } return nil, ErrUnsupportedDatabaseType @@ -82,7 +95,7 @@ type DatabaseType interface { RenewUser(statements Statements, username, expiration string) error RevokeUser(statements Statements, username string) error - ConnectionProducer + Close() CredentialsProducer } @@ -92,6 +105,7 @@ type DatabaseConfig struct { MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` } // Statments set in role creation and passed into the database type's functions. diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index c2fc085ae..4e1da240c 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -111,6 +111,12 @@ reduced to the same size.`, Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, }, + + "plugin_command": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Maximum amount of time a connection may be reused; + a zero or negative value reuses connections forever.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -146,6 +152,9 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connType := data.Get("connection_type").(string) + if connType == "" { + return logical.ErrorResponse("connection_type not set"), nil + } maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -173,6 +182,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, + PluginCommand: data.Get("plugin_command").(string), } name := data.Get("name").(string) diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 3f7a513c8..c7989c25d 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -65,7 +65,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, err } - expiration := db.GenerateExpiration(role.DefaultTTL) + expiration, err := db.GenerateExpiration(role.DefaultTTL) + if err != nil { + return nil, err + } err = db.CreateUser(role.Statements, username, password, expiration) if err != nil { From 748c70cfb4d844d8dc56a05ae3b0fd9bffd8781a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 17:43:58 -0800 Subject: [PATCH 018/162] Add plugin file --- builtin/logical/database/dbs/plugin.go | 242 +++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 builtin/logical/database/dbs/plugin.go diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go new file mode 100644 index 000000000..e495dbf14 --- /dev/null +++ b/builtin/logical/database/dbs/plugin.go @@ -0,0 +1,242 @@ +package dbs + +import ( + "net/rpc" + "os/exec" + "sync" + "time" + + "github.com/hashicorp/go-plugin" +) + +// handshakeConfigs are used to just do a basic handshake between +// a plugin and host. If the handshake fails, a user friendly error is shown. +// This prevents users from executing bad plugins or executing a plugin +// directory. It is a UX feature, not a security feature. +var handshakeConfig = plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "BASIC_PLUGIN", + MagicCookieValue: "hello", +} + +type DatabasePlugin struct { + impl DatabaseType +} + +func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { + return &databasePluginRPCServer{impl: d.impl}, nil +} + +func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { + return &databasePluginRPCClient{client: c}, nil +} + +type DatabasePluginClient struct { + client *plugin.Client + sync.Mutex + + *databasePluginRPCClient +} + +func (dc *DatabasePluginClient) Close() { + dc.databasePluginRPCClient.Close() + + dc.client.Kill() +} + +func newPluginClient(command string) (DatabaseType, error) { + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": new(DatabasePlugin), + } + + client := plugin.NewClient(&plugin.ClientConfig{ + HandshakeConfig: handshakeConfig, + Plugins: pluginMap, + Cmd: exec.Command(command), + }) + + // Connect via RPC + rpcClient, err := client.Client() + if err != nil { + return nil, err + } + + // Request the plugin + raw, err := rpcClient.Dispense("database") + if err != nil { + return nil, err + } + + // We should have a Greeter now! This feels like a normal interface + // implementation but is in fact over an RPC connection. + databaseRPC := raw.(*databasePluginRPCClient) + + return &DatabasePluginClient{ + client: client, + databasePluginRPCClient: databaseRPC, + }, nil +} + +func NewPluginServer(db DatabaseType) { + dbPlugin := &DatabasePlugin{ + impl: db, + } + + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": dbPlugin, + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: handshakeConfig, + Plugins: pluginMap, + }) +} + +// ---- RPC client domain ---- + +type databasePluginRPCClient struct { + client *rpc.Client +} + +func (dr *databasePluginRPCClient) Type() string { + return "plugin" +} + +func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { + req := CreateUserRequest{ + Statements: statements, + Username: username, + Password: password, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { + req := RenewUserRequest{ + Statements: statements, + Username: username, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { + req := RevokeUserRequest{ + Statements: statements, + Username: username, + } + + err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Close() error { + err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { + var username string + err := dr.client.Call("Plugin.GenerateUsername", displayName, &username) + + return username, err +} + +func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { + var password string + err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, &password) + + return password, err +} + +func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { + var expiration string + err := dr.client.Call("Plugin.GenerateExpiration", duration, &expiration) + + return expiration, err +} + +// ---- RPC server domain ---- +type databasePluginRPCServer struct { + impl DatabaseType +} + +func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { + *resp = "string" + return nil +} + +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { + err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { + err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { + err := ds.impl.RevokeUser(args.Statements, args.Username) + + return err +} + +func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { + ds.impl.Close() + return nil +} + +func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *string) error { + var err error + *resp, err = ds.impl.GenerateUsername(args) + + return err +} + +func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *string) error { + var err error + *resp, err = ds.impl.GeneratePassword() + + return err +} + +func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *string) error { + var err error + *resp, err = ds.impl.GenerateExpiration(args) + + return err +} + +// ---- Request Args domain ---- + +type CreateUserRequest struct { + Statements Statements + Username string + Password string + Expiration string +} + +type RenewUserRequest struct { + Statements Statements + Username string + Expiration string +} + +type RevokeUserRequest struct { + Statements Statements + Username string +} From fda45f531d5330acdcba68abd1f4bde9781d2678 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 21:31:29 -0800 Subject: [PATCH 019/162] Add special path to enforce root on plugin configuration --- builtin/logical/database/backend.go | 9 +- builtin/logical/database/dbs/db.go | 37 +-- .../database/path_config_connection.go | 213 ++++++++++-------- 3 files changed, 146 insertions(+), 113 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index fe853d3fb..e06e7b381 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -20,8 +20,15 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), + PathsSpecial: &logical.Paths{ + Root: []string{ + "dbs/plugin/*", + }, + }, + Paths: []*framework.Path{ - pathConfigConnection(&b), + pathConfigureConnection(&b), + pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 063cc89cf..bf78d29e6 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,7 +20,9 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -func Factory(conf *DatabaseConfig) (DatabaseType, error) { +type Factory func(*DatabaseConfig) (DatabaseType, error) + +func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: var connProducer *sqlConnectionProducer @@ -72,23 +74,24 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { ConnectionProducer: connProducer, CredentialsProducer: credsProducer, }, nil - - case pluginTypeName: - if conf.PluginCommand == "" { - return nil, errors.New("ERROR") - } - - db, err := newPluginClient(conf.PluginCommand) - if err != nil { - return nil, err - } - - return db, nil } return nil, ErrUnsupportedDatabaseType } +func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { + if conf.PluginCommand == "" { + return nil, errors.New("ERROR") + } + + db, err := newPluginClient(conf.PluginCommand) + if err != nil { + return nil, err + } + + return db, nil +} + type DatabaseType interface { Type() string CreateUser(statements Statements, username, password, expiration string) error @@ -108,6 +111,14 @@ type DatabaseConfig struct { PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` } +func (dc *DatabaseConfig) GetFactory() Factory { + if dc.DatabaseType == pluginTypeName { + return PluginFactory + } + + return BuiltinFactory +} + // Statments set in role creation and passed into the database type's functions. // TODO: Add a way of setting defaults here. type Statements struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4e1da240c..4780dc492 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -59,7 +59,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew } db.Close() - db, err = dbs.Factory(&config) + + factory := config.GetFactory() + + db, err = factory(&config) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } @@ -69,9 +72,17 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -func pathConfigConnection(b *databaseBackend) *framework.Path { +func pathConfigureConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler()) +} + +func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler()) +} + +func buildConfigConnectionPath(path string, updateOp, readOp framework.OperationFunc) *framework.Path { return &framework.Path{ - Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), + Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")), Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, @@ -120,8 +131,8 @@ reduced to the same size.`, }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: b.pathConnectionWrite, - logical.ReadOperation: b.pathConnectionRead, + logical.UpdateOperation: updateOp, + logical.ReadOperation: readOp, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -130,115 +141,119 @@ reduced to the same size.`, } // pathConnectionRead reads out the connection configuration -func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) +func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") - } - if entry == nil { - return nil, nil - } + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } - var config dbs.DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { - return nil, err + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil } - return &logical.Response{ - Data: structs.New(config).Map(), - }, nil } -func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connType := data.Get("connection_type").(string) - if connType == "" { - return logical.ErrorResponse("connection_type not set"), nil - } - - maxOpenConns := data.Get("max_open_connections").(int) - if maxOpenConns == 0 { - maxOpenConns = 2 - } - - maxIdleConns := data.Get("max_idle_connections").(int) - if maxIdleConns == 0 { - maxIdleConns = maxOpenConns - } - if maxIdleConns > maxOpenConns { - maxIdleConns = maxOpenConns - } - - maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) - maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_connection_lifetime: %s", err)), nil - } - - config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: data.Raw, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, - MaxConnectionLifetime: maxConnLifetime, - PluginCommand: data.Get("plugin_command").(string), - } - - name := data.Get("name").(string) - - // Grab the mutex lock - b.Lock() - defer b.Unlock() - - var db dbs.DatabaseType - if _, ok := b.connections[name]; ok { - - // Don't allow the connection type to change - if b.connections[name].Type() != connType { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - } else { - db, err = dbs.Factory(config) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil +func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + connType := data.Get("connection_type").(string) + if connType == "" { + return logical.ErrorResponse("connection_type not set"), nil } - b.connections[name] = db - } + maxOpenConns := data.Get("max_open_connections").(int) + if maxOpenConns == 0 { + maxOpenConns = 2 + } - /* TODO: - // Don't check the connection_url if verification is disabled - verifyConnection := data.Get("verify_connection").(bool) - if verifyConnection { - // Verify the string - db, err := sql.Open("postgres", connURL) + maxIdleConns := data.Get("max_idle_connections").(int) + if maxIdleConns == 0 { + maxIdleConns = maxOpenConns + } + if maxIdleConns > maxOpenConns { + maxIdleConns = maxOpenConns + } + + maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) + maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) if err != nil { return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + "Invalid max_connection_lifetime: %s", err)), nil } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + + config := &dbs.DatabaseConfig{ + DatabaseType: connType, + ConnectionDetails: data.Raw, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + MaxConnectionLifetime: maxConnLifetime, + PluginCommand: data.Get("plugin_command").(string), } - } - */ - // Store it - entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err - } + name := data.Get("name").(string) - // Reset the DB connection - resp := &logical.Response{} - resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + // Grab the mutex lock + b.Lock() + defer b.Unlock() - return resp, nil + var db dbs.DatabaseType + if _, ok := b.connections[name]; ok { + + // Don't allow the connection type to change + if b.connections[name].Type() != connType { + return logical.ErrorResponse("Can not change type of existing connection."), nil + } + } else { + db, err = factory(config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + b.connections[name] = db + } + + /* TODO: + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + } + */ + + // Store it + entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + // Reset the DB connection + resp := &logical.Response{} + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + + return resp, nil + } } const pathConfigConnectionHelpSyn = ` From a11911d4d44868521f5ba1e7bf010966836116de Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 22:35:45 -0800 Subject: [PATCH 020/162] Rename reset to close --- builtin/logical/database/backend.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index e06e7b381..69d91f6f2 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -39,7 +39,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { secretCreds(&b), }, - Clean: b.resetAllDBs, + Clean: b.closeAllDBs, } b.logger = conf.Logger @@ -56,7 +56,7 @@ type databaseBackend struct { } // resetAllDBs closes all connections from all database types -func (b *databaseBackend) resetAllDBs() { +func (b *databaseBackend) closeAllDBs() { b.logger.Trace("postgres/resetdb: enter") defer b.logger.Trace("postgres/resetdb: exit") From 71b81aad23d62ae0a49f15995f39226403bb4d8b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 10 Mar 2017 14:10:42 -0800 Subject: [PATCH 021/162] Add checksum attribute --- builtin/logical/database/dbs/db.go | 7 ++++++- builtin/logical/database/dbs/plugin.go | 2 +- builtin/logical/database/path_config_connection.go | 7 +++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index bf78d29e6..33cf7361a 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -84,7 +84,11 @@ func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { return nil, errors.New("ERROR") } - db, err := newPluginClient(conf.PluginCommand) + if conf.PluginChecksum == "" { + return nil, errors.New("ERROR") + } + + db, err := newPluginClient(conf.PluginCommand, conf.PluginChecksum) if err != nil { return nil, err } @@ -109,6 +113,7 @@ type DatabaseConfig struct { MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` + PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"` } func (dc *DatabaseConfig) GetFactory() Factory { diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index e495dbf14..bbd8d4ce4 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -44,7 +44,7 @@ func (dc *DatabasePluginClient) Close() { dc.client.Kill() } -func newPluginClient(command string) (DatabaseType, error) { +func newPluginClient(command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4780dc492..31f618281 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -128,6 +128,12 @@ reduced to the same size.`, Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, }, + + "plugin_checksum": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Maximum amount of time a connection may be reused; + a zero or negative value reuses connections forever.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -197,6 +203,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, PluginCommand: data.Get("plugin_command").(string), + PluginChecksum: data.Get("plugin_checksum").(string), } name := data.Get("name").(string) From 2054fff89063dd3aa51c404e5bc59db73dff2efc Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 13 Mar 2017 14:39:55 -0700 Subject: [PATCH 022/162] Add a way to initalize plugins and builtin databases the same way. --- .../database/dbs/connectionproducer.go | 54 +++++++++++++++++-- builtin/logical/database/dbs/db.go | 21 ++------ builtin/logical/database/dbs/plugin.go | 12 +++++ .../database/path_config_connection.go | 24 +++++++++ 4 files changed, 90 insertions(+), 21 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 82da37cc7..8d05e5d9e 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -3,6 +3,7 @@ package dbs import ( "crypto/tls" "database/sql" + "errors" "fmt" "strings" "sync" @@ -11,14 +12,20 @@ import ( // Import sql drivers _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" + "github.com/mitchellh/mapstructure" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" ) +var ( + errNotInitalized = errors.New("Connection has not been initalized") +) + type ConnectionProducer interface { Close() + Initialize(map[string]interface{}) error sync.Locker connection() (interface{}, error) @@ -30,10 +37,28 @@ type sqlConnectionProducer struct { config *DatabaseConfig - db *sql.DB + initalized bool + db *sql.DB sync.Mutex } +func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.initalized = true + + if _, err := c.connection(); err != nil { + return fmt.Errorf("Error Initalizing Connection: %s", err) + } + + return nil +} + func (c *sqlConnectionProducer) connection() (interface{}, error) { // If we already have a DB, test it and return if c.db != nil { @@ -98,13 +123,34 @@ type cassandraConnectionProducer struct { TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - config *DatabaseConfig - - session *gocql.Session + config *DatabaseConfig + initalized bool + session *gocql.Session sync.Mutex } +func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.initalized = true + + if _, err := c.connection(); err != nil { + return fmt.Errorf("Error Initalizing Connection: %s", err) + } + + return nil +} + func (c *cassandraConnectionProducer) connection() (interface{}, error) { + if !c.initalized { + return nil, errNotInitalized + } + // If we already have a DB, return it if c.session != nil { return c.session, nil diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 33cf7361a..98443f8f2 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -5,8 +5,6 @@ import ( "fmt" "strings" "time" - - "github.com/mitchellh/mapstructure" ) const ( @@ -25,11 +23,7 @@ type Factory func(*DatabaseConfig) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: - var connProducer *sqlConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &sqlConnectionProducer{} connProducer.config = conf credsProducer := &sqlCredentialsProducer{ @@ -43,11 +37,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case mySQLTypeName: - var connProducer *sqlConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &sqlConnectionProducer{} connProducer.config = conf credsProducer := &sqlCredentialsProducer{ @@ -61,11 +51,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case cassandraTypeName: - var connProducer *cassandraConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &cassandraConnectionProducer{} connProducer.config = conf credsProducer := &cassandraCredentialsProducer{} @@ -102,6 +88,7 @@ type DatabaseType interface { RenewUser(statements Statements, username, expiration string) error RevokeUser(statements Statements, username string) error + Initialize(map[string]interface{}) error Close() CredentialsProducer } diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index bbd8d4ce4..b244a33fc 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -140,6 +140,12 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st return err } +func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { + err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) + + return err +} + func (dr *databasePluginRPCClient) Close() error { err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) @@ -195,6 +201,12 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct return err } +func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { + err := ds.impl.Initialize(args) + + return err +} + func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { ds.impl.Close() return nil diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 31f618281..6c0a63a11 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -3,6 +3,7 @@ package database import ( "errors" "fmt" + "strings" "time" "github.com/fatih/structs" @@ -67,6 +68,11 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + err = db.Initialize(config.ConnectionDetails) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + b.connections[name] = db return nil, nil @@ -207,6 +213,11 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty name attribute given"), nil + } + + verifyConnection := data.Get("verify_connection").(bool) // Grab the mutex lock b.Lock() @@ -225,6 +236,19 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + err := db.Initialize(config.ConnectionDetails) + if err != nil { + if !strings.Contains(err.Error(), "Error Initializing Connection") { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + + } + + if verifyConnection { + return logical.ErrorResponse(err.Error()), nil + + } + } + b.connections[name] = db } From 822a3eb20aa11503b48158e84d44d76bb9074419 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 14 Mar 2017 13:11:28 -0700 Subject: [PATCH 023/162] Add a metrics middleware --- .../database/dbs/connectionproducer.go | 10 +- builtin/logical/database/dbs/db.go | 27 +++- .../logical/database/dbs/metricsmiddleware.go | 145 ++++++++++++++++++ builtin/logical/database/dbs/plugin.go | 9 +- 4 files changed, 176 insertions(+), 15 deletions(-) create mode 100644 builtin/logical/database/dbs/metricsmiddleware.go diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 8d05e5d9e..1e944c7b9 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -24,7 +24,7 @@ var ( ) type ConnectionProducer interface { - Close() + Close() error Initialize(map[string]interface{}) error sync.Locker @@ -97,7 +97,7 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { return c.db, nil } -func (c *sqlConnectionProducer) Close() { +func (c *sqlConnectionProducer) Close() error { // Grab the write lock c.Lock() defer c.Unlock() @@ -107,6 +107,8 @@ func (c *sqlConnectionProducer) Close() { } c.db = nil + + return nil } type cassandraConnectionProducer struct { @@ -167,7 +169,7 @@ func (c *cassandraConnectionProducer) connection() (interface{}, error) { return session, nil } -func (c *cassandraConnectionProducer) Close() { +func (c *cassandraConnectionProducer) Close() error { // Grab the write lock c.Lock() defer c.Unlock() @@ -177,6 +179,8 @@ func (c *cassandraConnectionProducer) Close() { } c.session = nil + + return nil } func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 98443f8f2..2cc42a731 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -21,6 +21,8 @@ var ( type Factory func(*DatabaseConfig) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { + var dbType DatabaseType + switch conf.DatabaseType { case postgreSQLTypeName: connProducer := &sqlConnectionProducer{} @@ -31,10 +33,10 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { usernameLen: 63, } - return &PostgreSQL{ + dbType = &PostgreSQL{ ConnectionProducer: connProducer, CredentialsProducer: credsProducer, - }, nil + } case mySQLTypeName: connProducer := &sqlConnectionProducer{} @@ -45,10 +47,10 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { usernameLen: 16, } - return &MySQL{ + dbType = &MySQL{ ConnectionProducer: connProducer, CredentialsProducer: credsProducer, - }, nil + } case cassandraTypeName: connProducer := &cassandraConnectionProducer{} @@ -56,13 +58,22 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { credsProducer := &cassandraCredentialsProducer{} - return &Cassandra{ + dbType = &Cassandra{ ConnectionProducer: connProducer, CredentialsProducer: credsProducer, - }, nil + } + + default: + return nil, ErrUnsupportedDatabaseType } - return nil, ErrUnsupportedDatabaseType + // Wrap with metrics middleware + dbType = &databaseMetricsMiddleware{ + next: dbType, + typeStr: dbType.Type(), + } + + return dbType, nil } func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { @@ -89,7 +100,7 @@ type DatabaseType interface { RevokeUser(statements Statements, username string) error Initialize(map[string]interface{}) error - Close() + Close() error CredentialsProducer } diff --git a/builtin/logical/database/dbs/metricsmiddleware.go b/builtin/logical/database/dbs/metricsmiddleware.go new file mode 100644 index 000000000..61b4bd4eb --- /dev/null +++ b/builtin/logical/database/dbs/metricsmiddleware.go @@ -0,0 +1,145 @@ +package dbs + +import ( + "time" + + metrics "github.com/armon/go-metrics" +) + +type databaseMetricsMiddleware struct { + next DatabaseType + + typeStr string +} + +func (mw *databaseMetricsMiddleware) Type() string { + return mw.next.Type() +} + +func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "CreateUser"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "CreateUser", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "CreateUser"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) + return mw.next.CreateUser(statements, username, password, expiration) +} + +func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "RenewUser"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "RenewUser", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "RenewUser"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) + return mw.next.RenewUser(statements, username, expiration) +} + +func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "RevokeUser"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "RevokeUser", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) + return mw.next.RevokeUser(statements, username) +} + +func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "Initialize"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "Initialize"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) + return mw.next.Initialize(conf) +} + +func (mw *databaseMetricsMiddleware) Close() (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "Close"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "Close", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "Close"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) + return mw.next.Close() +} + +func (mw *databaseMetricsMiddleware) GenerateUsername(displayName string) (_ string, err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "GenerateUsername"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateUsername"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "GenerateUsername", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "GenerateUsername"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername"}, 1) + return mw.next.GenerateUsername(displayName) +} + +func (mw *databaseMetricsMiddleware) GeneratePassword() (_ string, err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "GeneratePassword"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "GeneratePassword"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "GeneratePassword", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "GeneratePassword"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword"}, 1) + return mw.next.GeneratePassword() +} + +func (mw *databaseMetricsMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "GenerateExpiration"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateExpiration"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "GenerateExpiration", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "GenerateExpiration"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration"}, 1) + return mw.next.GenerateExpiration(duration) +} diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index b244a33fc..7b2b18e00 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -38,10 +38,11 @@ type DatabasePluginClient struct { *databasePluginRPCClient } -func (dc *DatabasePluginClient) Close() { - dc.databasePluginRPCClient.Close() - +func (dc *DatabasePluginClient) Close() error { + err := dc.databasePluginRPCClient.Close() dc.client.Kill() + + return err } func newPluginClient(command, checksum string) (DatabaseType, error) { @@ -179,7 +180,7 @@ type databasePluginRPCServer struct { } func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - *resp = "string" + *resp = ds.impl.Type() return nil } From 3ecb344878ba38ed964ac4ddb3591dee09fe5778 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 14 Mar 2017 13:12:47 -0700 Subject: [PATCH 024/162] wrap plugin database type with metrics middleware --- builtin/logical/database/dbs/db.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2cc42a731..3b10db464 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -90,6 +90,12 @@ func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { return nil, err } + // Wrap with metrics middleware + db = &databaseMetricsMiddleware{ + next: db, + typeStr: db.Type(), + } + return db, nil } From eb6117cbb26cfc5babc9948e0fa4544000cd4a8e Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 15 Mar 2017 17:14:48 -0700 Subject: [PATCH 025/162] Work on TLS communication over plugins --- builtin/logical/database/dbs/db.go | 10 +- builtin/logical/database/dbs/plugin.go | 269 +++++++++++++++++- .../database/path_config_connection.go | 4 +- logical/system_view.go | 9 + vault/dynamic_system_view.go | 27 ++ 5 files changed, 311 insertions(+), 8 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 3b10db464..b681de360 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" "time" + + "github.com/hashicorp/vault/logical" ) const ( @@ -18,9 +20,9 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -type Factory func(*DatabaseConfig) (DatabaseType, error) +type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error) -func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { +func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { var dbType DatabaseType switch conf.DatabaseType { @@ -76,7 +78,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { return dbType, nil } -func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { +func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { if conf.PluginCommand == "" { return nil, errors.New("ERROR") } @@ -85,7 +87,7 @@ func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { return nil, errors.New("ERROR") } - db, err := newPluginClient(conf.PluginCommand, conf.PluginChecksum) + db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 7b2b18e00..e4f5359a7 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,12 +1,31 @@ package dbs import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "errors" + "fmt" + "math/big" + mathrand "math/rand" "net/rpc" + "net/url" + "os" "os/exec" + "strings" "sync" "time" + "github.com/SermoDigital/jose/jws" + "github.com/hashicorp/errwrap" "github.com/hashicorp/go-plugin" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/logical" ) // handshakeConfigs are used to just do a basic handshake between @@ -45,16 +64,155 @@ func (dc *DatabasePluginClient) Close() error { return err } -func newPluginClient(command, checksum string) (DatabaseType, error) { +func generateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + // c.logger.Error("core: failed to generate replicated cluster signing key", "error", err) + return nil, nil, nil, err + } + + //c.logger.Trace("core: generating replicated cluster certificate") + + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + // 30 years of single-active uptime ought to be enough for anybody + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) + if err != nil { + // c.logger.Error("core: error generating self-signed cert for replication", "error", err) + return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) + } + + caCert, err := x509.ParseCertificate(certBytes) + if err != nil { + // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + return certBytes, caCert, key, nil +} + +func generateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + + clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) + } + + clientCert, err := x509.ParseCertificate(certBytes) + if err != nil { + // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + keyBytes, err := x509.MarshalECPrivateKey(clientKey) + if err != nil { + return nil, nil, nil, err + } + + return certBytes, clientCert, keyBytes, nil +} + +func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), } + CACertBytes, CACert, CAKey, err := generateX509Cert() + if err != nil { + return nil, err + } + + clientCertBytes, clientCert, clientKey, err := generateClientCert(CACert, CAKey) + if err != nil { + return nil, err + } + + /* serverCert, serverKey, err := generateClientCert(CACert, CAKey) + if err != nil { + return nil, err + }*/ + serverKey, err := x509.MarshalECPrivateKey(CAKey) + if err != nil { + return nil, err + } + cert := tls.Certificate{ + Certificate: [][]byte{clientCertBytes, CACertBytes}, + PrivateKey: clientKey, + Leaf: clientCert, + } + + clientCertPool := x509.NewCertPool() + clientCertPool.AddCert(CACert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: clientCertPool, + ClientCAs: clientCertPool, + ServerName: CACert.Subject.CommonName, + MinVersion: tls.VersionTLS12, + } + + tlsConfig.BuildNameToCertificate() + + wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ + "CACert": CACertBytes, + "ServerCert": CACertBytes, + "ServerKey": serverKey, + }, time.Second*10, true) + + cmd := exec.Command(command) + cmd.Env = append(cmd.Env, fmt.Sprintf("VAULT_WRAP_TOKEN=%s", wrapToken)) + client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, - Cmd: exec.Command(command), + Cmd: cmd, + TLSConfig: tlsConfig, }) // Connect via RPC @@ -92,9 +250,116 @@ func NewPluginServer(db DatabaseType) { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, + TLSProvider: VaultPluginTLSProvider, }) } +func VaultPluginTLSProvider() (*tls.Config, error) { + unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") + if strings.Count(unwrapToken, ".") != 2 { + return nil, errors.New("Could not parse unwraptoken") + } + + wt, err := jws.ParseJWT([]byte(unwrapToken)) + if err != nil { + return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) + } + if wt == nil { + return nil, errors.New("nil decoded token") + } + + addrRaw := wt.Claims().Get("addr") + if addrRaw == nil { + return nil, errors.New("decoded token does not contain primary cluster address") + } + vaultAddr, ok := addrRaw.(string) + if !ok { + return nil, errors.New("decoded token's address not valid") + } + if vaultAddr == "" { + return nil, errors.New(`no address for the vault found`) + } + + // Sanity check the value + if _, err := url.Parse(vaultAddr); err != nil { + return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) + } + + clientConf := api.DefaultConfig() + clientConf.Address = vaultAddr + client, err := api.NewClient(clientConf) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + secret, err := client.Logical().Unwrap(unwrapToken) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + CABytesRaw, ok := secret.Data["CACert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + CACert, err := x509.ParseCertificate(CABytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCert, err := x509.ParseCertificate(serverCertBytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKeyRaw, ok := secret.Data["ServerKey"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + caCertPool := x509.NewCertPool() + caCertPool.AddCert(CACert) + + cert := tls.Certificate{ + Certificate: [][]byte{serverCertBytes}, + PrivateKey: serverKey, + Leaf: serverCert, + } + + // Setup TLS config + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + RootCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + // TLS 1.2 minimum + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil +} + // ---- RPC client domain ---- type databasePluginRPCClient struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 6c0a63a11..0a99ad196 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -63,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew factory := config.GetFactory() - db, err = factory(&config) + db, err = factory(&config, b.System()) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } @@ -231,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return logical.ErrorResponse("Can not change type of existing connection."), nil } } else { - db, err = factory(config) + db, err = factory(config, b.System()) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } diff --git a/logical/system_view.go b/logical/system_view.go index d769397df..56254b33a 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -1,6 +1,7 @@ package logical import ( + "errors" "time" "github.com/hashicorp/vault/helper/consts" @@ -37,6 +38,10 @@ type SystemView interface { // ReplicationState indicates the state of cluster replication ReplicationState() consts.ReplicationState + + // ResponseWrapData wraps the given data in a cubbyhole and returns the + // token used to unwrap. + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) } type StaticSystemView struct { @@ -72,3 +77,7 @@ func (d StaticSystemView) CachingDisabled() bool { func (d StaticSystemView) ReplicationState() consts.ReplicationState { return d.ReplicationStateVal } + +func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { + return "", errors.New("ResponseWrapData is not implimented in StaticSystemView") +} diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 32c906fae..4c6807ace 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -87,3 +87,30 @@ func (d dynamicSystemView) ReplicationState() consts.ReplicationState { d.core.clusterParamsLock.RUnlock() return state } + +// ResponseWrapData wraps the given data in a cubbyhole and returns the +// token used to unwrap. +func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { + req := &logical.Request{ + Operation: logical.CreateOperation, + Path: "sys/init", + } + + resp := &logical.Response{ + WrapInfo: &logical.ResponseWrapInfo{ + TTL: ttl, + }, + Data: data, + } + + if jwt { + resp.WrapInfo.Format = "jwt" + } + + _, err := d.core.wrapInCubbyhole(req, resp) + if err != nil { + return "", err + } + + return resp.WrapInfo.Token, nil +} From 0a52ea5c69e1ebff833841ee9f06063b0e70fcfd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 11:55:21 -0700 Subject: [PATCH 026/162] Break tls code into helper library --- builtin/logical/database/dbs/plugin.go | 220 +------------------------ helper/pluginutil/tls.go | 218 ++++++++++++++++++++++++ 2 files changed, 222 insertions(+), 216 deletions(-) create mode 100644 helper/pluginutil/tls.go diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index e4f5359a7..b4649fc7f 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,30 +1,16 @@ package dbs import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" - "encoding/base64" - "errors" "fmt" - "math/big" - mathrand "math/rand" "net/rpc" - "net/url" - "os" "os/exec" - "strings" "sync" "time" - "github.com/SermoDigital/jose/jws" - "github.com/hashicorp/errwrap" "github.com/hashicorp/go-plugin" - uuid "github.com/hashicorp/go-uuid" - "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" ) @@ -64,110 +50,18 @@ func (dc *DatabasePluginClient) Close() error { return err } -func generateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { - key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - if err != nil { - // c.logger.Error("core: failed to generate replicated cluster signing key", "error", err) - return nil, nil, nil, err - } - - //c.logger.Trace("core: generating replicated cluster certificate") - - host, err := uuid.GenerateUUID() - if err != nil { - return nil, nil, nil, err - } - host = "localhost" - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: host, - }, - DNSNames: []string{host}, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - // 30 years of single-active uptime ought to be enough for anybody - NotAfter: time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) - if err != nil { - // c.logger.Error("core: error generating self-signed cert for replication", "error", err) - return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) - } - - caCert, err := x509.ParseCertificate(certBytes) - if err != nil { - // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) - } - - return certBytes, caCert, key, nil -} - -func generateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { - host, err := uuid.GenerateUUID() - if err != nil { - return nil, nil, nil, err - } - host = "localhost" - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: host, - }, - DNSNames: []string{host}, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageClientAuth, - x509.ExtKeyUsageServerAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - - clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - if err != nil { - return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) - if err != nil { - return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) - } - - clientCert, err := x509.ParseCertificate(certBytes) - if err != nil { - // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) - } - - keyBytes, err := x509.MarshalECPrivateKey(clientKey) - if err != nil { - return nil, nil, nil, err - } - - return certBytes, clientCert, keyBytes, nil -} - func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), } - CACertBytes, CACert, CAKey, err := generateX509Cert() + CACertBytes, CACert, CAKey, err := pluginutil.GenerateX509Cert() if err != nil { return nil, err } - clientCertBytes, clientCert, clientKey, err := generateClientCert(CACert, CAKey) + clientCertBytes, clientCert, clientKey, err := pluginutil.GenerateClientCert(CACert, CAKey) if err != nil { return nil, err } @@ -250,116 +144,10 @@ func NewPluginServer(db DatabaseType) { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, - TLSProvider: VaultPluginTLSProvider, + TLSProvider: pluginutil.VaultPluginTLSProvider, }) } -func VaultPluginTLSProvider() (*tls.Config, error) { - unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") - if strings.Count(unwrapToken, ".") != 2 { - return nil, errors.New("Could not parse unwraptoken") - } - - wt, err := jws.ParseJWT([]byte(unwrapToken)) - if err != nil { - return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) - } - if wt == nil { - return nil, errors.New("nil decoded token") - } - - addrRaw := wt.Claims().Get("addr") - if addrRaw == nil { - return nil, errors.New("decoded token does not contain primary cluster address") - } - vaultAddr, ok := addrRaw.(string) - if !ok { - return nil, errors.New("decoded token's address not valid") - } - if vaultAddr == "" { - return nil, errors.New(`no address for the vault found`) - } - - // Sanity check the value - if _, err := url.Parse(vaultAddr); err != nil { - return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) - } - - clientConf := api.DefaultConfig() - clientConf.Address = vaultAddr - client, err := api.NewClient(clientConf) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - - secret, err := client.Logical().Unwrap(unwrapToken) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - - CABytesRaw, ok := secret.Data["CACert"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - CACert, err := x509.ParseCertificate(CABytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverCert, err := x509.ParseCertificate(serverCertBytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverKeyRaw, ok := secret.Data["ServerKey"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - caCertPool := x509.NewCertPool() - caCertPool.AddCert(CACert) - - cert := tls.Certificate{ - Certificate: [][]byte{serverCertBytes}, - PrivateKey: serverKey, - Leaf: serverCert, - } - - // Setup TLS config - tlsConfig := &tls.Config{ - ClientCAs: caCertPool, - RootCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - // TLS 1.2 minimum - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{cert}, - } - tlsConfig.BuildNameToCertificate() - - return tlsConfig, nil -} - // ---- RPC client domain ---- type databasePluginRPCClient struct { diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go new file mode 100644 index 000000000..55f27a388 --- /dev/null +++ b/helper/pluginutil/tls.go @@ -0,0 +1,218 @@ +package pluginutil + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "errors" + "fmt" + "math/big" + mathrand "math/rand" + "net/url" + "os" + "strings" + "time" + + "github.com/SermoDigital/jose/jws" + "github.com/hashicorp/errwrap" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" +) + +func GenerateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, nil, err + } + + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + // 30 years of single-active uptime ought to be enough for anybody + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) + if err != nil { + return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) + } + + caCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + return certBytes, caCert, key, nil +} + +func GenerateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + + clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) + } + + clientCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + keyBytes, err := x509.MarshalECPrivateKey(clientKey) + if err != nil { + return nil, nil, nil, err + } + + return certBytes, clientCert, keyBytes, nil +} + +// VaultPluginTLSProvider is run inside a plugin and retrives the response +// wrapped TLS certificate from vault. It returns a configured tlsConfig. +func VaultPluginTLSProvider() (*tls.Config, error) { + unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") + if strings.Count(unwrapToken, ".") != 2 { + return nil, errors.New("Could not parse unwraptoken") + } + + wt, err := jws.ParseJWT([]byte(unwrapToken)) + if err != nil { + return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) + } + if wt == nil { + return nil, errors.New("nil decoded token") + } + + addrRaw := wt.Claims().Get("addr") + if addrRaw == nil { + return nil, errors.New("decoded token does not contain primary cluster address") + } + vaultAddr, ok := addrRaw.(string) + if !ok { + return nil, errors.New("decoded token's address not valid") + } + if vaultAddr == "" { + return nil, errors.New(`no address for the vault found`) + } + + // Sanity check the value + if _, err := url.Parse(vaultAddr); err != nil { + return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) + } + + clientConf := api.DefaultConfig() + clientConf.Address = vaultAddr + client, err := api.NewClient(clientConf) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + secret, err := client.Logical().Unwrap(unwrapToken) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + CABytesRaw, ok := secret.Data["CACert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + CACert, err := x509.ParseCertificate(CABytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCert, err := x509.ParseCertificate(serverCertBytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKeyRaw, ok := secret.Data["ServerKey"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + caCertPool := x509.NewCertPool() + caCertPool.AddCert(CACert) + + cert := tls.Certificate{ + Certificate: [][]byte{serverCertBytes}, + PrivateKey: serverKey, + Leaf: serverCert, + } + + // Setup TLS config + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + RootCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + // TLS 1.2 minimum + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil +} From f2df4ef0e74801fe5d605c12db88ff5de8279a23 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 14:14:49 -0700 Subject: [PATCH 027/162] Comment and slight refactor of the TLS plugin helper --- builtin/logical/database/dbs/plugin.go | 45 +++---------- helper/pluginutil/tls.go | 89 +++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index b4649fc7f..c068128d8 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,8 +1,6 @@ package dbs import ( - "crypto/tls" - "crypto/x509" "fmt" "net/rpc" "os/exec" @@ -56,57 +54,34 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database "database": new(DatabasePlugin), } - CACertBytes, CACert, CAKey, err := pluginutil.GenerateX509Cert() + // Get a CA TLS Certificate + CACertBytes, CACert, CAKey, err := pluginutil.GenerateCACert() if err != nil { return nil, err } - clientCertBytes, clientCert, clientKey, err := pluginutil.GenerateClientCert(CACert, CAKey) + // Use CA to sign a client cert and return a configured TLS config + clientTLSConfig, err := pluginutil.CreateClientTLSConfig(CACert, CAKey) if err != nil { return nil, err } - /* serverCert, serverKey, err := generateClientCert(CACert, CAKey) - if err != nil { - return nil, err - }*/ - serverKey, err := x509.MarshalECPrivateKey(CAKey) + // Use CA to sign a server cert and wrap the values in a response wrapped + // token. + wrapToken, err := pluginutil.WrapServerConfig(sys, CACertBytes, CACert, CAKey) if err != nil { return nil, err } - cert := tls.Certificate{ - Certificate: [][]byte{clientCertBytes, CACertBytes}, - PrivateKey: clientKey, - Leaf: clientCert, - } - - clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(CACert) - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: clientCertPool, - ClientCAs: clientCertPool, - ServerName: CACert.Subject.CommonName, - MinVersion: tls.VersionTLS12, - } - - tlsConfig.BuildNameToCertificate() - - wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ - "CACert": CACertBytes, - "ServerCert": CACertBytes, - "ServerKey": serverKey, - }, time.Second*10, true) + // Add the response wrap token to the ENV of the plugin cmd := exec.Command(command) - cmd.Env = append(cmd.Env, fmt.Sprintf("VAULT_WRAP_TOKEN=%s", wrapToken)) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, Cmd: cmd, - TLSConfig: tlsConfig, + TLSConfig: clientTLSConfig, }) // Connect via RPC diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 55f27a388..10ca8583a 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -21,9 +21,16 @@ import ( "github.com/hashicorp/errwrap" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/logical" ) -func GenerateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { +var ( + PluginUnwrapTokenEnv = "VAULT_WRAP_TOKEN" +) + +// GenerateCACert returns a CA cert used to later sign the certificates for the +// plugin client and server. +func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { return nil, nil, nil, err @@ -65,7 +72,9 @@ func GenerateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { return certBytes, caCert, key, nil } -func GenerateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { +// generateSignedCert is used internally to create certificates for the plugin +// client and server. These certs are signed by the given CA Cert and Key. +func generateSignedCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { host, err := uuid.GenerateUUID() if err != nil { return nil, nil, nil, err @@ -101,22 +110,71 @@ func GenerateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]by return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) } - keyBytes, err := x509.MarshalECPrivateKey(clientKey) + return certBytes, clientCert, clientKey, nil +} + +// CreateClientTLSConfig creates a signed certificate and returns a configured +// TLS config. +func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (*tls.Config, error) { + clientCertBytes, clientCert, clientKey, err := generateSignedCert(CACert, CAKey) if err != nil { - return nil, nil, nil, err + return nil, err } - return certBytes, clientCert, keyBytes, nil + cert := tls.Certificate{ + Certificate: [][]byte{clientCertBytes}, + PrivateKey: clientKey, + Leaf: clientCert, + } + + clientCertPool := x509.NewCertPool() + clientCertPool.AddCert(CACert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: clientCertPool, + ClientCAs: clientCertPool, + ServerName: CACert.Subject.CommonName, + MinVersion: tls.VersionTLS12, + } + + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil +} + +// WrapServerConfig is used to create a server certificate and private key, then +// wrap them in an unwrap token for later retrieval by the plugin. +func WrapServerConfig(sys logical.SystemView, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { + serverCertBytes, _, serverKey, err := generateSignedCert(CACert, CAKey) + if err != nil { + return "", err + } + rawKey, err := x509.MarshalECPrivateKey(serverKey) + if err != nil { + return "", err + } + + wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ + "CACert": CACertBytes, + "ServerCert": serverCertBytes, + "ServerKey": rawKey, + }, time.Second*10, true) + + return wrapToken, err } // VaultPluginTLSProvider is run inside a plugin and retrives the response -// wrapped TLS certificate from vault. It returns a configured tlsConfig. +// wrapped TLS certificate from vault. It returns a configured TLS Config. func VaultPluginTLSProvider() (*tls.Config, error) { - unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") + unwrapToken := os.Getenv(PluginUnwrapTokenEnv) + + // Ensure unwrap token is a JWT if strings.Count(unwrapToken, ".") != 2 { return nil, errors.New("Could not parse unwraptoken") } + // Parse the JWT and retrieve the vault address wt, err := jws.ParseJWT([]byte(unwrapToken)) if err != nil { return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) @@ -142,6 +200,7 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) } + // Unwrap the token clientConf := api.DefaultConfig() clientConf.Address = vaultAddr client, err := api.NewClient(clientConf) @@ -154,9 +213,10 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) } + // Retrieve and parse the CA Certificate CABytesRaw, ok := secret.Data["CACert"].(string) if !ok { - return nil, errors.New("error unmarshalling certificate") + return nil, errors.New("error unmarshalling CA certificate") } CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) @@ -169,6 +229,7 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, fmt.Errorf("error parsing certificate: %v", err) } + // Retrieve and parse the server's certificate serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) if !ok { return nil, errors.New("error unmarshalling certificate") @@ -184,19 +245,27 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, fmt.Errorf("error parsing certificate: %v", err) } - serverKeyRaw, ok := secret.Data["ServerKey"].(string) + // Retrieve and parse the server's private key + serverKeyB64, ok := secret.Data["ServerKey"].(string) if !ok { return nil, errors.New("error unmarshalling certificate") } - serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) + serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) if err != nil { return nil, fmt.Errorf("error parsing certificate: %v", err) } + serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Add CA cert to the cert pool caCertPool := x509.NewCertPool() caCertPool.AddCert(CACert) + // Build a certificate object out of the server's cert and private key. cert := tls.Certificate{ Certificate: [][]byte{serverCertBytes}, PrivateKey: serverKey, From d453008dea245ad8458c94348b6ff9660531d9a6 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 14:17:44 -0700 Subject: [PATCH 028/162] Update the name of PluginUnwrapTokenEnv --- helper/pluginutil/tls.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 10ca8583a..88d88689d 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -25,7 +25,9 @@ import ( ) var ( - PluginUnwrapTokenEnv = "VAULT_WRAP_TOKEN" + // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the + // plugin. + PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) // GenerateCACert returns a CA cert used to later sign the certificates for the From 287382584850f7bb51708058fbff0e73bb07be10 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 16:20:18 -0700 Subject: [PATCH 029/162] Add a secure config to verify the checksum of the plugin --- builtin/logical/database/dbs/plugin.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index c068128d8..96d28bf90 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,6 +1,8 @@ package dbs import ( + "crypto/sha256" + "encoding/hex" "fmt" "net/rpc" "os/exec" @@ -77,11 +79,22 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database cmd := exec.Command(command) cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) + checksumDecoded, err := hex.DecodeString(checksum) + if err != nil { + return nil, err + } + + secureConfig := &plugin.SecureConfig{ + Checksum: checksumDecoded, + Hash: sha256.New(), + } + client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, Cmd: cmd, TLSConfig: clientTLSConfig, + SecureConfig: secureConfig, }) // Connect via RPC From 417770a58f7f1a1a80b2c14e28ac748a376fd633 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 17:51:25 -0700 Subject: [PATCH 030/162] Change the handshake config from the default --- builtin/logical/database/dbs/plugin.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 96d28bf90..8fdcb81f0 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -20,8 +20,8 @@ import ( // directory. It is a UX feature, not a security feature. var handshakeConfig = plugin.HandshakeConfig{ ProtocolVersion: 1, - MagicCookieKey: "BASIC_PLUGIN", - MagicCookieValue: "hello", + MagicCookieKey: "VAULT_DATABASE_PLUGIN", + MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", } type DatabasePlugin struct { From a4e5e0f8c96d8236f627662fa8e63ec6d7c275c3 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 18:24:56 -0700 Subject: [PATCH 031/162] Comment and fix plugin Type function --- builtin/logical/database/dbs/plugin.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 8fdcb81f0..45c815a45 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -36,6 +36,8 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e return &databasePluginRPCClient{client: c}, nil } +// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close +// method to also call Close() on the plugin.Client. type DatabasePluginClient struct { client *plugin.Client sync.Mutex @@ -50,6 +52,9 @@ func (dc *DatabasePluginClient) Close() error { return err } +// newPluginClient returns a databaseRPCClient with a connection to a running +// plugin. The client is wrapped in a DatabasePluginClient object to ensure the +// plugin is killed on call of Close(). func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ @@ -119,6 +124,9 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database }, nil } +// NewPluginServer is called from within a plugin and wraps the provided +// DatabaseType implimentation in a databasePluginRPCServer object and starts a +// RPC server. func NewPluginServer(db DatabaseType) { dbPlugin := &DatabasePlugin{ impl: db, @@ -138,12 +146,18 @@ func NewPluginServer(db DatabaseType) { // ---- RPC client domain ---- +// databasePluginRPCClient impliments DatabaseType and is used on the client to +// make RPC calls to a plugin. type databasePluginRPCClient struct { client *rpc.Client } func (dr *databasePluginRPCClient) Type() string { - return "plugin" + var dbType string + //TODO: catch error + dr.client.Call("Plugin.Type", struct{}{}, &dbType) + + return fmt.Sprintf("plugin-%s", dbType) } func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { @@ -216,6 +230,8 @@ func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (s } // ---- RPC server domain ---- + +// databasePluginRPCServer impliments DatabaseType and is run inside a plugin type databasePluginRPCServer struct { impl DatabaseType } From 83ff1327053a1df432c7647bae417dbfc85656f8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 21 Mar 2017 16:05:59 -0700 Subject: [PATCH 032/162] Verify connections regardless of if this connections is already existing --- .../database/path_config_connection.go | 60 +++++++------------ 1 file changed, 23 insertions(+), 37 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 0a99ad196..dc8cf34e5 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -223,53 +223,39 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. b.Lock() defer b.Unlock() - var db dbs.DatabaseType + db, err := factory(config, b.System()) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + err = db.Initialize(config.ConnectionDetails) + if err != nil { + if !strings.Contains(err.Error(), "Error Initializing Connection") { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + if verifyConnection { + return logical.ErrorResponse(err.Error()), nil + } + } + if _, ok := b.connections[name]; ok { + // Don't update connection until the reset api is hit, close for + // now. + err = db.Close() + if err != nil { + return nil, err + } // Don't allow the connection type to change if b.connections[name].Type() != connType { return logical.ErrorResponse("Can not change type of existing connection."), nil } } else { - db, err = factory(config, b.System()) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - } - - err := db.Initialize(config.ConnectionDetails) - if err != nil { - if !strings.Contains(err.Error(), "Error Initializing Connection") { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - - } - - if verifyConnection { - return logical.ErrorResponse(err.Error()), nil - - } - } - + // Save the new connection b.connections[name] = db } - /* TODO: - // Don't check the connection_url if verification is disabled - verifyConnection := data.Get("verify_connection").(bool) - if verifyConnection { - // Verify the string - db, err := sql.Open("postgres", connURL) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil - } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil - } - } - */ - // Store it entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) if err != nil { From 85ef468d46ac1a7de3e3eba40f6c604e6670e36f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 21 Mar 2017 17:19:30 -0700 Subject: [PATCH 033/162] Add a delete method --- builtin/logical/database/backend.go | 2 +- .../database/path_config_connection.go | 44 +++++++++++++++++-- builtin/logical/database/path_role_create.go | 2 +- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 69d91f6f2..610865253 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -27,7 +27,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { }, Paths: []*framework.Path{ - pathConfigureConnection(&b), + pathConfigureBuiltinConnection(&b), pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index dc8cf34e5..3bb3a5631 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -78,15 +78,22 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -func pathConfigureConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler()) +// pathConfigureBuiltinConnection returns a configured framework.Path setup to +// operate on builtin databases. +func pathConfigureBuiltinConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) } +// pathConfigurePluginConnection returns a configured framework.Path setup to +// operate on plugins. func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler()) + return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) } -func buildConfigConnectionPath(path string, updateOp, readOp framework.OperationFunc) *framework.Path { +// buildConfigConnectionPath reutns a configured framework.Path using the passed +// in operation functions to complete the request. Used to distinguish calls +// between builtin and plugin databases. +func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework.OperationFunc) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")), Fields: map[string]*framework.FieldSchema{ @@ -145,6 +152,7 @@ reduced to the same size.`, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.UpdateOperation: updateOp, logical.ReadOperation: readOp, + logical.DeleteOperation: deleteOp, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -175,6 +183,34 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { } } +// connectionDeleteHandler deletes the connection configuration +func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty name attribute given"), nil + } + + err := req.Storage.Delete(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to delete connection configuration") + } + + if _, ok := b.connections[name]; ok { + err = b.connections[name].Close() + if err != nil { + return nil, err + } + } + + delete(b.connections, name) + + return nil, nil + } +} + +// connectionWriteHandler returns a handler function for creating and updating +// both builtin and plugin database types. func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connType := data.Get("connection_type").(string) diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index c7989c25d..14b65cbb3 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -42,7 +42,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil } - // Generate the username, password and expiration. PG limits user to 63 characters + // Generate the username, password and expiration // Get our handle b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") From c55bef85d38ce8c0e1d4ff6503a500c4f484a347 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 09:54:19 -0700 Subject: [PATCH 034/162] Fix race with deleting the connection --- builtin/logical/database/path_config_connection.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 3bb3a5631..ba6a37805 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -196,6 +196,9 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return nil, fmt.Errorf("failed to delete connection configuration") } + b.Lock() + defer b.Unlock() + if _, ok := b.connections[name]; ok { err = b.connections[name].Close() if err != nil { From ae9961b81156e14d4b922ae73cea37d6d9a4ea0b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 12:40:16 -0700 Subject: [PATCH 035/162] Add a error message for empty creation statement --- builtin/logical/database/dbs/db.go | 1 + builtin/logical/database/dbs/mysql.go | 3 +-- builtin/logical/database/dbs/postgresql.go | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index b681de360..4554963ac 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -18,6 +18,7 @@ const ( var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") + ErrEmptyCreationStatement = errors.New("Empty creation statements") ) type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0d8be2a47..54940d8f6 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -2,7 +2,6 @@ package dbs import ( "database/sql" - "fmt" "strings" "github.com/hashicorp/vault/helper/strutil" @@ -43,7 +42,7 @@ func (m *MySQL) CreateUser(statements Statements, username, password, expiration } if statements.CreationStatements == "" { - return fmt.Errorf("Empty creation statements") + return ErrEmptyCreationStatement } // Start a transaction diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 51b72ebc8..20d548f92 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -28,6 +28,10 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { } func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error { + if statements.CreationStatements == "" { + return ErrEmptyCreationStatement + } + // Grab the lock p.Lock() defer p.Unlock() From dac1bb210b210b7c789edc4616a1df75d51c836f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 16:39:08 -0700 Subject: [PATCH 036/162] Add test files for postgres and mysql databases --- builtin/logical/database/dbs/mysql_test.go | 349 +++++++++++++++ .../logical/database/dbs/postgresql_test.go | 412 ++++++++++++++++++ 2 files changed, 761 insertions(+) create mode 100644 builtin/logical/database/dbs/mysql_test.go create mode 100644 builtin/logical/database/dbs/postgresql_test.go diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go new file mode 100644 index 000000000..a27dfbba7 --- /dev/null +++ b/builtin/logical/database/dbs/mysql_test.go @@ -0,0 +1,349 @@ +package dbs + +import ( + "database/sql" + "os" + "sync" + "testing" + "time" + + dockertest "gopkg.in/ory-am/dockertest.v2" +) + +var ( + testMySQLImagePull sync.Once +) + +func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { + if os.Getenv("PG_URL") != "" { + return "", os.Getenv("PG_URL") + } + + // Without this the checks for whether the container has started seem to + // never actually pass. There's really no reason to expose the test + // containers, so don't. + dockertest.BindDockerToLocalhost = "yep" + + testImagePull.Do(func() { + dockertest.Pull("mysql") + }) + + cid, connErr := dockertest.ConnectToMySQL(60, 500*time.Millisecond, func(connURL string) bool { + // This will cause a validation to run + connProducer := &sqlConnectionProducer{} + connProducer.ConnectionURL = connURL + connProducer.config = &DatabaseConfig{ + DatabaseType: mySQLTypeName, + } + + conn, err := connProducer.connection() + if err != nil { + return false + } + if err := conn.(*sql.DB).Ping(); err != nil { + return false + } + + connProducer.Close() + + retURL = connURL + return true + }) + + if connErr != nil { + t.Fatalf("could not connect to database: %v", connErr) + } + + return +} + +func TestMySQL_Initialize(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + dbRaw, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Deconsturct the middleware chain to get the underlying postgres object + dbMetrics := dbRaw.(*databaseMetricsMiddleware) + db := dbMetrics.next.(*MySQL) + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) + if !connProducer.initalized { + t.Fatal("Database should be initalized") + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + if connProducer.db != nil { + t.Fatal("db object should be nil") + } +} + +func TestMySQL_CreateUser(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + statements.CreationStatements = testMySQLRoleHost + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + /* statements.CreationStatements = testBlockStatementRole + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + }*/ +} + +func TestMySQL_RenewUser(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.RenewUser(statements, username, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMySQL_RevokeUser(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements.CreationStatements = testMySQLRoleHost + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = testMySQLRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + +} + +const testMySQLRoleWildCard = ` +CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'%'; +` +const testMySQLRoleHost = ` +CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2'; +` +const testMySQLRevocationSQL = ` +REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2'; +DROP USER '{{name}}'@'10.1.1.2'; +` diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go new file mode 100644 index 000000000..211ab0254 --- /dev/null +++ b/builtin/logical/database/dbs/postgresql_test.go @@ -0,0 +1,412 @@ +package dbs + +import ( + "database/sql" + "os" + "sync" + "testing" + "time" + + dockertest "gopkg.in/ory-am/dockertest.v2" +) + +var ( + testImagePull sync.Once +) + +func prepareTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { + if os.Getenv("PG_URL") != "" { + return "", os.Getenv("PG_URL") + } + + // Without this the checks for whether the container has started seem to + // never actually pass. There's really no reason to expose the test + // containers, so don't. + dockertest.BindDockerToLocalhost = "yep" + + testImagePull.Do(func() { + dockertest.Pull("postgres") + }) + + cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { + // This will cause a validation to run + connProducer := &sqlConnectionProducer{} + connProducer.ConnectionURL = connURL + connProducer.config = &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + } + + conn, err := connProducer.connection() + if err != nil { + return false + } + if err := conn.(*sql.DB).Ping(); err != nil { + return false + } + + connProducer.Close() + + retURL = connURL + return true + }) + + if connErr != nil { + t.Fatalf("could not connect to database: %v", connErr) + } + + return +} + +func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { + err := cid.KillRemove() + if err != nil { + t.Fatal(err) + } +} + +func TestPostgreSQL_Initialize(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + dbRaw, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Deconsturct the middleware chain to get the underlying postgres object + dbMetrics := dbRaw.(*databaseMetricsMiddleware) + db := dbMetrics.next.(*PostgreSQL) + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) + if !connProducer.initalized { + t.Fatal("Database should be initalized") + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + if connProducer.db != nil { + t.Fatal("db object should be nil") + } +} + +func TestPostgreSQL_CreateUser(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := Statements{ + CreationStatements: testRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + statements.CreationStatements = testReadOnlyRole + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + /* statements.CreationStatements = testBlockStatementRole + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + }*/ +} + +func TestPostgreSQL_RenewUser(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.RenewUser(statements, username, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPostgreSQL_RevokeUser(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = defaultRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; +REVOKE USAGE ON SCHEMA public FROM "{{name}}"; + +DROP ROLE IF EXISTS "{{name}}"; +` From 106807670378748a187232322a33d982eb86dd0a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 16:44:33 -0700 Subject: [PATCH 037/162] s/postgres/mysql/ --- builtin/logical/database/dbs/mysql_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index a27dfbba7..e7edabeb7 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -15,8 +15,8 @@ var ( ) func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("PG_URL") != "" { - return "", os.Getenv("PG_URL") + if os.Getenv("MYSQL_URL") != "" { + return "", os.Getenv("MYSQL_URL") } // Without this the checks for whether the container has started seem to @@ -75,7 +75,7 @@ func TestMySQL_Initialize(t *testing.T) { t.Fatalf("err: %s", err) } - // Deconsturct the middleware chain to get the underlying postgres object + // Deconsturct the middleware chain to get the underlying mysql object dbMetrics := dbRaw.(*databaseMetricsMiddleware) db := dbMetrics.next.(*MySQL) From c0223d888e41fbf4da5abb6a736444276d46436e Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 17:09:39 -0700 Subject: [PATCH 038/162] Remove unsused code block --- builtin/logical/database/dbs/mysql_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index e7edabeb7..c489ffaea 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -185,11 +185,6 @@ func TestMySQL_CreateUser(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - /* statements.CreationStatements = testBlockStatementRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - }*/ } func TestMySQL_RenewUser(t *testing.T) { From 29ae4602dc7659cb019ebd3c4e57928c27e94f84 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 23 Mar 2017 15:54:15 -0700 Subject: [PATCH 039/162] More work on getting tests to pass --- builtin/logical/database/backend_test.go | 620 ------------------ builtin/logical/database/dbs/mysql_test.go | 2 +- builtin/logical/database/dbs/plugin.go | 9 +- builtin/logical/database/dbs/plugin_test.go | 325 +++++++++ .../logical/database/dbs/postgresql_test.go | 34 +- helper/pluginutil/tls.go | 7 +- vault/testing.go | 16 + 7 files changed, 369 insertions(+), 644 deletions(-) delete mode 100644 builtin/logical/database/backend_test.go create mode 100644 builtin/logical/database/dbs/plugin_test.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go deleted file mode 100644 index a203c9b19..000000000 --- a/builtin/logical/database/backend_test.go +++ /dev/null @@ -1,620 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "fmt" - "log" - "os" - "path" - "reflect" - "sync" - "testing" - "time" - - "github.com/hashicorp/vault/logical" - logicaltest "github.com/hashicorp/vault/logical/testing" - "github.com/lib/pq" - "github.com/mitchellh/mapstructure" - "github.com/ory-am/dockertest" -) - -var ( - testImagePull sync.Once -) - -func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("PG_URL") != "" { - return "", os.Getenv("PG_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testImagePull.Do(func() { - dockertest.Pull("postgres") - }) - - cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - resp, err := b.HandleRequest(&logical.Request{ - Storage: s, - Operation: logical.UpdateOperation, - Path: "config/connection", - Data: map[string]interface{}{ - "connection_url": connURL, - }, - }) - if err != nil || (resp != nil && resp.IsError()) { - // It's likely not up and running yet, so return false and try again - return false - } - if resp == nil { - t.Fatal("expected warning") - } - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { - err := cid.KillRemove() - if err != nil { - t.Fatal(err) - } -} - -func TestBackend_config_connection(t *testing.T) { - var resp *logical.Response - var err error - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - configData := map[string]interface{}{ - "connection_url": "sample_connection_url", - "value": "", - "max_open_connections": 9, - "max_idle_connections": 7, - "verify_connection": false, - } - - configReq := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/connection", - Storage: config.StorageView, - Data: configData, - } - resp, err = b.HandleRequest(configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - configReq.Operation = logical.ReadOperation - resp, err = b.HandleRequest(configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - delete(configData, "verify_connection") - if !reflect.DeepEqual(configData, resp.Data) { - t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) - } -} - -func TestBackend_basic(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRole(t, "web", testRole, false), - testAccStepReadCreds(t, b, config.StorageView, "web", connURL), - }, - }) -} - -func TestBackend_roleCrud(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRole(t, "web", testRole, false), - testAccStepReadRole(t, "web", testRole), - testAccStepDeleteRole(t, "web"), - testAccStepReadRole(t, "web", ""), - }, - }) -} - -func TestBackend_BlockStatements(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice) - if err != nil { - t.Fatal(err) - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - // This will also validate the query - testAccStepCreateRole(t, "web-block", testBlockStatementRole, true), - testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false), - }, - }) -} - -func TestBackend_roleReadOnly(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRole(t, "web", testRole, false), - testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false), - testAccStepReadRole(t, "web-readonly", testReadOnlyRole), - testAccStepCreateTable(t, b, config.StorageView, "web", connURL), - testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), - testAccStepDropTable(t, b, config.StorageView, "web", connURL), - testAccStepDeleteRole(t, "web-readonly"), - testAccStepDeleteRole(t, "web"), - testAccStepReadRole(t, "web-readonly", ""), - }, - }) -} - -func TestBackend_roleReadOnly_revocationSQL(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false), - testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false), - testAccStepReadRole(t, "web-readonly", testReadOnlyRole), - testAccStepCreateTable(t, b, config.StorageView, "web", connURL), - testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), - testAccStepDropTable(t, b, config.StorageView, "web", connURL), - testAccStepDeleteRole(t, "web-readonly"), - testAccStepDeleteRole(t, "web"), - testAccStepReadRole(t, "web-readonly", ""), - }, - }) -} - -func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/connection", - Data: d, - ErrorOk: true, - Check: func(resp *logical.Response) error { - if expectError { - if resp.Data == nil { - return fmt.Errorf("data is nil") - } - var e struct { - Error string `mapstructure:"error"` - } - if err := mapstructure.Decode(resp.Data, &e); err != nil { - return err - } - if len(e.Error) == 0 { - return fmt.Errorf("expected error, but write succeeded.") - } - return nil - } else if resp != nil && resp.IsError() { - return fmt.Errorf("got an error response: %v", resp.Error()) - } - return nil - }, - } -} - -func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: path.Join("roles", name), - Data: map[string]interface{}{ - "sql": sql, - }, - ErrorOk: expectFail, - } -} - -func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: path.Join("roles", name), - Data: map[string]interface{}{ - "sql": sql, - "revocation_sql": revocationSQL, - }, - ErrorOk: expectFail, - } -} - -func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.DeleteOperation, - Path: path.Join("roles", name), - } -} - -func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: path.Join("creds", name), - Check: func(resp *logical.Response) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - returnedRows := func() int { - stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") - if err != nil { - return -1 - } - defer stmt.Close() - - rows, err := stmt.Query(d.Username) - if err != nil { - return -1 - } - defer rows.Close() - - i := 0 - for rows.Next() { - i++ - } - return i - } - - // minNumPermissions is the minimum number of permissions that will always be present. - const minNumPermissions = 2 - - userRows := returnedRows() - if userRows < minNumPermissions { - t.Fatalf("did not get expected number of rows, got %d", userRows) - } - - resp, err = b.HandleRequest(&logical.Request{ - Operation: logical.RevokeOperation, - Storage: s, - Secret: &logical.Secret{ - InternalData: map[string]interface{}{ - "secret_type": "creds", - "username": d.Username, - "role": name, - }, - }, - }) - if err != nil { - return err - } - if resp != nil { - if resp.IsError() { - return fmt.Errorf("Error on resp: %#v", *resp) - } - } - - userRows = returnedRows() - // User shouldn't exist so returnedRows() should encounter an error and exit with -1 - if userRows != -1 { - t.Fatalf("did not get expected number of rows, got %d", userRows) - } - - return nil - }, - } -} - -func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: path.Join("creds", name), - Check: func(resp *logical.Response) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);") - if err != nil { - t.Fatal(err) - } - - resp, err = b.HandleRequest(&logical.Request{ - Operation: logical.RevokeOperation, - Storage: s, - Secret: &logical.Secret{ - InternalData: map[string]interface{}{ - "secret_type": "creds", - "username": d.Username, - }, - }, - }) - if err != nil { - return err - } - if resp != nil { - if resp.IsError() { - return fmt.Errorf("Error on resp: %#v", *resp) - } - } - - return nil - }, - } -} - -func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: path.Join("creds", name), - Check: func(resp *logical.Response) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("DROP TABLE test;") - if err != nil { - t.Fatal(err) - } - - resp, err = b.HandleRequest(&logical.Request{ - Operation: logical.RevokeOperation, - Storage: s, - Secret: &logical.Secret{ - InternalData: map[string]interface{}{ - "secret_type": "creds", - "username": d.Username, - }, - }, - }) - if err != nil { - return err - } - if resp != nil { - if resp.IsError() { - return fmt.Errorf("Error on resp: %#v", *resp) - } - } - - return nil - }, - } -} - -func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "roles/" + name, - Check: func(resp *logical.Response) error { - if resp == nil { - if sql == "" { - return nil - } - - return fmt.Errorf("bad: %#v", resp) - } - - var d struct { - SQL string `mapstructure:"sql"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - - if d.SQL != sql { - return fmt.Errorf("bad: %#v", resp) - } - - return nil - }, - } -} - -const testRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -` - -const testReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - -const testBlockStatementRole = ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ - -CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; -GRANT "foo-role" TO "{{name}}"; -ALTER ROLE "{{name}}" SET search_path = foo; -GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; -` - -var testBlockStatementRoleSlice = []string{ - ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ -`, - `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, - `GRANT "foo-role" TO "{{name}}";`, - `ALTER ROLE "{{name}}" SET search_path = foo;`, - `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, -} - -const defaultRevocationSQL = ` -REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; -REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; -REVOKE USAGE ON SCHEMA public FROM {{name}}; - -DROP ROLE IF EXISTS {{name}}; -` diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index c489ffaea..f4d124702 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -24,7 +24,7 @@ func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL // containers, so don't. dockertest.BindDockerToLocalhost = "yep" - testImagePull.Do(func() { + testMySQLImagePull.Do(func() { dockertest.Pull("mysql") }) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 45c815a45..1213a3677 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -6,12 +6,12 @@ import ( "fmt" "net/rpc" "os/exec" + "strings" "sync" "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" - "github.com/hashicorp/vault/logical" ) // handshakeConfigs are used to just do a basic handshake between @@ -55,7 +55,7 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { +func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), @@ -81,7 +81,8 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database } // Add the response wrap token to the ENV of the plugin - cmd := exec.Command(command) + commandArr := strings.Split(command, " ") + cmd := exec.Command(commandArr[0], commandArr[1]) cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) checksumDecoded, err := hex.DecodeString(checksum) @@ -265,7 +266,7 @@ func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *st return err } -func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { +func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { ds.impl.Close() return nil } diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/dbs/plugin_test.go new file mode 100644 index 000000000..74e103c4a --- /dev/null +++ b/builtin/logical/database/dbs/plugin_test.go @@ -0,0 +1,325 @@ +package dbs + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" +) + +var ( + testPluginImagePull sync.Once +) + +type mockPlugin struct { + users map[string][]string + CredentialsProducer +} + +func (m *mockPlugin) Type() string { return "mock" } +func (m *mockPlugin) CreateUser(statements Statements, username, password, expiration string) error { + err := errors.New("err") + if username == "" || password == "" || expiration == "" { + return err + } + + if _, ok := m.users[username]; ok { + return err + } + + m.users[username] = []string{password, expiration} + + return nil +} +func (m *mockPlugin) RenewUser(statements Statements, username, expiration string) error { + err := errors.New("err") + if username == "" || expiration == "" { + return err + } + + if _, ok := m.users[username]; !ok { + return err + } + + return nil +} +func (m *mockPlugin) RevokeUser(statements Statements, username string) error { + err := errors.New("err") + if username == "" { + return err + } + + if _, ok := m.users[username]; !ok { + return err + } + + delete(m.users, username) + return nil +} +func (m *mockPlugin) Initialize(conf map[string]interface{}) error { + err := errors.New("err") + if len(conf) != 1 { + return err + } + + return nil +} +func (m *mockPlugin) Close() error { + m.users = nil + return nil +} + +func getConf(t *testing.T) *DatabaseConfig { + command := fmt.Sprintf("%s -test.run=TestPlugin_Main", os.Args[0]) + cmd := exec.Command(os.Args[0]) + hash := sha256.New() + + file, err := os.Open(cmd.Path) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + _, err = io.Copy(hash, file) + if err != nil { + t.Fatal(err) + } + + sum := hash.Sum(nil) + + conf := &DatabaseConfig{ + DatabaseType: pluginTypeName, + PluginCommand: command, + PluginChecksum: hex.EncodeToString(sum), + ConnectionDetails: map[string]interface{}{ + "test": true, + }, + } + + return conf +} + +func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { + core, _, _, ln := vault.TestCoreUnsealedWithListener(t) + http.TestServerWithListener(t, ln, "", core) + sys := vault.TestDynamicSystemView(core) + + return core, ln, sys +} + +func TestPlugin_Main(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + plugin := &mockPlugin{ + users: make(map[string][]string), + CredentialsProducer: &sqlCredentialsProducer{5, 50}, + } + + NewPluginServer(plugin) +} + +func TestPlugin_Initialize(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + dbRaw, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_CreateUser(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + db, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + // try and save the same user again to verify it saved the first time, this + // should return an error + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("expected an error, user wasn't created correctly") + } + + // Create one more user + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_RenewUser(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + db, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.RenewUser(Statements{}, username, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_RevokeUser(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + db, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(Statements{}, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Try adding the same username back so we can verify it was removed + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // try once more + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.RevokeUser(Statements{}, username) + if err != nil { + t.Fatalf("err: %s", err) + } + +} diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go index 211ab0254..dab720920 100644 --- a/builtin/logical/database/dbs/postgresql_test.go +++ b/builtin/logical/database/dbs/postgresql_test.go @@ -11,10 +11,10 @@ import ( ) var ( - testImagePull sync.Once + testPostgresImagePull sync.Once ) -func prepareTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { +func preparePostgresTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { if os.Getenv("PG_URL") != "" { return "", os.Getenv("PG_URL") } @@ -24,7 +24,7 @@ func prepareTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL stri // containers, so don't. dockertest.BindDockerToLocalhost = "yep" - testImagePull.Do(func() { + testPostgresImagePull.Do(func() { dockertest.Pull("postgres") }) @@ -65,7 +65,7 @@ func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { } func TestPostgreSQL_Initialize(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -107,7 +107,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { } func TestPostgreSQL_CreateUser(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -150,7 +150,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } statements := Statements{ - CreationStatements: testRole, + CreationStatements: testPostgresRole, } err = db.CreateUser(statements, username, password, expiration) @@ -172,7 +172,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - statements.CreationStatements = testReadOnlyRole + statements.CreationStatements = testPostgresReadOnlyRole err = db.CreateUser(statements, username, password, expiration) if err != nil { t.Fatalf("err: %s", err) @@ -200,7 +200,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } func TestPostgreSQL_RenewUser(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -237,7 +237,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { } statements := Statements{ - CreationStatements: testRole, + CreationStatements: testPostgresRole, } err = db.CreateUser(statements, username, password, expiration) @@ -256,7 +256,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { } func TestPostgreSQL_RevokeUser(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -293,7 +293,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } statements := Statements{ - CreationStatements: testRole, + CreationStatements: testPostgresRole, } err = db.CreateUser(statements, username, password, expiration) @@ -333,7 +333,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } // Test custom revoke statements - statements.RevocationStatements = defaultRevocationSQL + statements.RevocationStatements = defaultPostgresRevocationSQL err = db.RevokeUser(statements, username) if err != nil { t.Fatalf("err: %s", err) @@ -341,7 +341,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } -const testRole = ` +const testPostgresRole = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' @@ -349,7 +349,7 @@ CREATE ROLE "{{name}}" WITH GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; ` -const testReadOnlyRole = ` +const testPostgresReadOnlyRole = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' @@ -358,7 +358,7 @@ GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; ` -const testBlockStatementRole = ` +const testPostgresBlockStatementRole = ` DO $$ BEGIN IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN @@ -380,7 +380,7 @@ ALTER ROLE "{{name}}" SET search_path = foo; GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; ` -var testBlockStatementRoleSlice = []string{ +var testPostgresBlockStatementRoleSlice = []string{ ` DO $$ BEGIN @@ -403,7 +403,7 @@ $$ `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, } -const defaultRevocationSQL = ` +const defaultPostgresRevocationSQL = ` REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; REVOKE USAGE ON SCHEMA public FROM "{{name}}"; diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 88d88689d..08f24985d 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -21,7 +21,6 @@ import ( "github.com/hashicorp/errwrap" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/logical" ) var ( @@ -30,6 +29,10 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) +type Wrapper interface { + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) +} + // GenerateCACert returns a CA cert used to later sign the certificates for the // plugin client and server. func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { @@ -147,7 +150,7 @@ func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (* // WrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys logical.SystemView, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { +func WrapServerConfig(sys Wrapper, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { serverCertBytes, _, serverKey, err := generateSignedCert(CACert, CAKey) if err != nil { return "", err diff --git a/vault/testing.go b/vault/testing.go index b567fe75e..7b914bbdb 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -231,6 +231,18 @@ func TestCoreUnsealedBackend(t testing.TB, backend physical.Backend) (*Core, [][ return core, keys, token } +func TestCoreUnsealedWithListener(t testing.TB) (*Core, [][]byte, string, net.Listener) { + core, keys, token := TestCoreUnsealed(t) + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + addr := "http://" + ln.Addr().String() + core.redirectAddr = addr + + return core, keys, token, ln +} + func testTokenStore(t testing.TB, c *Core) *TokenStore { me := &MountEntry{ Table: credentialTableType, @@ -293,6 +305,10 @@ func TestKeyCopy(key []byte) []byte { return result } +func TestDynamicSystemView(c *Core) *dynamicSystemView { + return &dynamicSystemView{c, nil} +} + var testLogicalBackends = map[string]logical.Factory{} // Starts the test server which responds to SSH authentication. From 2799586f45dc94921c960730d06eccde3cb70e59 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 27 Mar 2017 11:46:20 -0700 Subject: [PATCH 040/162] Remove the unused sync.Once object --- builtin/logical/database/dbs/plugin_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/dbs/plugin_test.go index 74e103c4a..151d0c88f 100644 --- a/builtin/logical/database/dbs/plugin_test.go +++ b/builtin/logical/database/dbs/plugin_test.go @@ -9,7 +9,6 @@ import ( "net" "os" "os/exec" - "sync" "testing" "time" @@ -19,10 +18,6 @@ import ( "github.com/hashicorp/vault/vault" ) -var ( - testPluginImagePull sync.Once -) - type mockPlugin struct { users map[string][]string CredentialsProducer @@ -119,6 +114,8 @@ func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { return core, ln, sys } +// This is not an actual test case, it's a helper function that will be executed +// by the go-plugin client via an exec call. func TestPlugin_Main(t *testing.T) { if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { return From 494f96358154b99f2650d2298e3d985944d68f1e Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 27 Mar 2017 15:17:28 -0700 Subject: [PATCH 041/162] Wrap the database calls with tracing information --- ...icsmiddleware.go => databasemiddleware.go} | 104 ++++++++++++++++++ builtin/logical/database/dbs/db.go | 21 +++- builtin/logical/database/dbs/plugin.go | 7 +- builtin/logical/database/dbs/postgresql.go | 6 - .../database/path_config_connection.go | 4 +- 5 files changed, 130 insertions(+), 12 deletions(-) rename builtin/logical/database/dbs/{metricsmiddleware.go => databasemiddleware.go} (60%) diff --git a/builtin/logical/database/dbs/metricsmiddleware.go b/builtin/logical/database/dbs/databasemiddleware.go similarity index 60% rename from builtin/logical/database/dbs/metricsmiddleware.go rename to builtin/logical/database/dbs/databasemiddleware.go index 61b4bd4eb..d3f037ecb 100644 --- a/builtin/logical/database/dbs/metricsmiddleware.go +++ b/builtin/logical/database/dbs/databasemiddleware.go @@ -4,8 +4,112 @@ import ( "time" metrics "github.com/armon/go-metrics" + log "github.com/mgutz/logxi/v1" ) +// ---- Tracing Middleware Domain ---- + +type databaseTracingMiddleware struct { + next DatabaseType + logger log.Logger + + typeStr string +} + +func (mw *databaseTracingMiddleware) Type() string { + return mw.next.Type() +} + +func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) + } + return mw.next.CreateUser(statements, username, password, expiration) +} + +func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/RenewUser: starting", "type", mw.typeStr) + } + return mw.next.RenewUser(statements, username, expiration) +} + +func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/RevokeUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/RevokeUser: starting", "type", mw.typeStr) + } + return mw.next.RevokeUser(statements, username) +} + +func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr) + } + return mw.next.Initialize(conf) +} + +func (mw *databaseTracingMiddleware) Close() (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/Close: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/Close: starting", "type", mw.typeStr) + } + return mw.next.Close() +} + +func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr) + } + return mw.next.GenerateUsername(displayName) +} + +func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr) + } + return mw.next.GeneratePassword() +} + +func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr) + } + return mw.next.GenerateExpiration(duration) +} + +// ---- Metrics Middleware Domain ---- + type databaseMetricsMiddleware struct { next DatabaseType diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 4554963ac..54581e465 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/vault/logical" + log "github.com/mgutz/logxi/v1" ) const ( @@ -21,9 +22,9 @@ var ( ErrEmptyCreationStatement = errors.New("Empty creation statements") ) -type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error) +type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) -func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { +func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { var dbType DatabaseType switch conf.DatabaseType { @@ -76,10 +77,17 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, typeStr: dbType.Type(), } + // Wrap with tracing middleware + dbType = &databaseTracingMiddleware{ + next: dbType, + typeStr: dbType.Type(), + logger: logger, + } + return dbType, nil } -func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { +func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { if conf.PluginCommand == "" { return nil, errors.New("ERROR") } @@ -99,6 +107,13 @@ func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, typeStr: db.Type(), } + // Wrap with tracing middleware + db = &databaseTracingMiddleware{ + next: db, + typeStr: db.Type(), + logger: logger, + } + return db, nil } diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 1213a3677..b1f9abe20 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -82,7 +82,12 @@ func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (Database // Add the response wrap token to the ENV of the plugin commandArr := strings.Split(command, " ") - cmd := exec.Command(commandArr[0], commandArr[1]) + var cmd *exec.Cmd + if len(commandArr) > 1 { + cmd = exec.Command(commandArr[0], commandArr[1]) + } else { + cmd = exec.Command(commandArr[0]) + } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) checksumDecoded, err := hex.DecodeString(checksum) diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 20d548f92..c8ba110cf 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -43,13 +43,11 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir } // Start a transaction - // b.logger.Trace("postgres/pathRoleCreateRead: starting transaction") tx, err := db.Begin() if err != nil { return err } defer func() { - // b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction") tx.Rollback() }() // Return the secret @@ -61,7 +59,6 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir continue } - // b.logger.Trace("postgres/pathRoleCreateRead: preparing statement") stmt, err := tx.Prepare(queryHelper(query, map[string]string{ "name": username, "password": password, @@ -71,15 +68,12 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir return err } defer stmt.Close() - // b.logger.Trace("postgres/pathRoleCreateRead: executing statement") if _, err := stmt.Exec(); err != nil { return err } } // Commit the transaction - - // b.logger.Trace("postgres/pathRoleCreateRead: committing transaction") if err := tx.Commit(); err != nil { return err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index ba6a37805..1b0878670 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -63,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew factory := config.GetFactory() - db, err = factory(&config, b.System()) + db, err = factory(&config, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } @@ -262,7 +262,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. b.Lock() defer b.Unlock() - db, err := factory(config, b.System()) + db, err := factory(config, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } From 02b0230f195bd3022d9ff4ac53f25c83c4142ed4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 10:04:42 -0700 Subject: [PATCH 042/162] Fix for checking types of database on update --- builtin/logical/database/path_config_connection.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 1b0878670..ff633e745 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -279,6 +279,8 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } if _, ok := b.connections[name]; ok { + newType := db.Type() + // Don't update connection until the reset api is hit, close for // now. err = db.Close() @@ -287,7 +289,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } // Don't allow the connection type to change - if b.connections[name].Type() != connType { + if b.connections[name].Type() != newType { return logical.ErrorResponse("Can not change type of existing connection."), nil } } else { From c50a6ebc3988a574aa42eb12e9fc16847ef397e4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 11:30:45 -0700 Subject: [PATCH 043/162] Add functionaility to build db objects from disk so restarts work --- builtin/logical/database/backend.go | 46 +++++++++++++++++-- .../database/dbs/connectionproducer.go | 8 ++-- builtin/logical/database/dbs/db.go | 5 +- .../database/path_config_connection.go | 39 +++------------- builtin/logical/database/path_role_create.go | 30 +++++------- builtin/logical/database/path_roles.go | 7 +++ builtin/logical/database/secret_creds.go | 28 +++++------ helper/pluginutil/tls.go | 3 ++ 8 files changed, 92 insertions(+), 74 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 610865253..f8bcc60f1 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -1,6 +1,7 @@ package database import ( + "fmt" "strings" "sync" @@ -52,14 +53,11 @@ type databaseBackend struct { logger log.Logger *framework.Backend - sync.RWMutex + sync.Mutex } // resetAllDBs closes all connections from all database types func (b *databaseBackend) closeAllDBs() { - b.logger.Trace("postgres/resetdb: enter") - defer b.logger.Trace("postgres/resetdb: exit") - b.Lock() defer b.Unlock() @@ -68,6 +66,46 @@ func (b *databaseBackend) closeAllDBs() { } } +// This function is used to retrieve a database object either from the cached +// connection map or by using the database config in storage. The caller of this +// function needs to hold the backend's lock. +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) { + // if the object already is built and cached, return it + db, ok := b.connections[name] + if ok { + return db, nil + } + + entry, err := s.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + } + if entry == nil { + return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) + } + + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + factory := config.GetFactory() + + db, err = factory(&config, b.System(), b.logger) + if err != nil { + return nil, err + } + + err = db.Initialize(config.ConnectionDetails) + if err != nil { + return nil, err + } + + b.connections[name] = db + + return db, nil +} + func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 1e944c7b9..dae8d9400 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -20,7 +20,7 @@ import ( ) var ( - errNotInitalized = errors.New("Connection has not been initalized") + errNotInitalized = errors.New("connection has not been initalized") ) type ConnectionProducer interface { @@ -142,7 +142,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) er c.initalized = true if _, err := c.connection(); err != nil { - return fmt.Errorf("Error Initalizing Connection: %s", err) + return fmt.Errorf("error Initalizing Connection: %s", err) } return nil @@ -244,7 +244,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { session, err := clusterConfig.CreateSession() if err != nil { - return nil, fmt.Errorf("Error creating session: %s", err) + return nil, fmt.Errorf("error creating session: %s", err) } // Set consistency @@ -260,7 +260,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { // Verify the info err = session.Query(`LIST USERS`).Exec() if err != nil { - return nil, fmt.Errorf("Error validating connection info: %s", err) + return nil, fmt.Errorf("error validating connection info: %s", err) } return session, nil diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 54581e465..74f5a2605 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -18,10 +18,11 @@ const ( ) var ( - ErrUnsupportedDatabaseType = errors.New("Unsupported database type") - ErrEmptyCreationStatement = errors.New("Empty creation statements") + ErrUnsupportedDatabaseType = errors.New("unsupported database type") + ErrEmptyCreationStatement = errors.New("empty creation statements") ) +// Factory function for type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index ff633e745..b4c699750 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,7 +1,6 @@ package database import ( - "errors" "fmt" "strings" "time" @@ -34,47 +33,24 @@ func pathResetConnection(b *databaseBackend) *framework.Path { func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return nil, errors.New("No database name set") + return logical.ErrorResponse("Empty name attribute given"), nil } // Grab the mutex lock b.Lock() defer b.Unlock() - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") - } - if entry == nil { - return nil, nil + db, ok := b.connections[name] + if ok { + db.Close() + delete(b.connections, name) } - var config dbs.DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { + db, err := b.getOrCreateDBObj(req.Storage, name) + if err != nil { return nil, err } - db, ok := b.connections[name] - if !ok { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - - db.Close() - - factory := config.GetFactory() - - db, err = factory(&config, b.System(), b.logger) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - } - - err = db.Initialize(config.ConnectionDetails) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - } - - b.connections[name] = db - return nil, nil } @@ -306,7 +282,6 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return nil, err } - // Reset the DB connection resp := &logical.Response{} resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 14b65cbb3..d379ef267 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -27,34 +27,28 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { } func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - b.logger.Trace("postgres/pathRoleCreateRead: enter") - defer b.logger.Trace("postgres/pathRoleCreateRead: exit") - name := data.Get("name").(string) // Get the role - b.logger.Trace("postgres/pathRoleCreateRead: getting role") role, err := b.Role(req.Storage, name) if err != nil { return nil, err } if role == nil { - return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil + return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil + } + + b.Lock() + defer b.Unlock() + + // Get the Database object + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + // TODO: return a resp error instead? + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } // Generate the username, password and expiration - - // Get our handle - b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") - - b.RLock() - defer b.RUnlock() - db, ok := b.connections[role.DBName] - if !ok { - // TODO: return a resp error instead? - return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) - } - username, err := db.GenerateUsername(req.DisplayName) if err != nil { return nil, err @@ -70,12 +64,12 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, err } + // Create the user err = db.CreateUser(role.Statements, username, password, expiration) if err != nil { return nil, err } - b.logger.Trace("postgres/pathRoleCreateRead: generating secret") resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ "username": username, "password": password, diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 9a5bb9324..6f62c79d9 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -126,7 +126,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty role name attribute given"), nil + } + dbName := data.Get("db_name").(string) + if dbName == "" { + return logical.ErrorResponse("Empty database name attribute given"), nil + } // Get statements creationStmts := data.Get("creation_statements").(string) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index e39525a18..2b63ea1f8 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -29,7 +29,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } role, err := b.Role(req.Storage, roleNameRaw.(string)) @@ -37,7 +37,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi return nil, err } if role == nil { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) @@ -47,13 +47,13 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi } // Grab the read lock - b.RLock() - defer b.RUnlock() + b.Lock() + defer b.Unlock() // Get our connection - db, ok := b.connections[role.DBName] - if !ok { - return nil, fmt.Errorf("Could not find connection with name %s", role.DBName) + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } role, err := b.Role(req.Storage, roleNameRaw.(string)) @@ -89,7 +89,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, err } if role == nil { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } /* TODO: think about how to handle this case. @@ -109,13 +109,13 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F }*/ // Grab the read lock - b.RLock() - defer b.RUnlock() + b.Lock() + defer b.Unlock() // Get our connection - db, ok := b.connections[role.DBName] - if !ok { - return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) } err = db.RevokeUser(role.Statements, username) diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 08f24985d..63ae2932f 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -217,6 +217,9 @@ func VaultPluginTLSProvider() (*tls.Config, error) { if err != nil { return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) } + if secret == nil { + return nil, errors.New("error during token unwrap request secret is nil") + } // Retrieve and parse the CA Certificate CABytesRaw, ok := secret.Data["CACert"].(string) From 6b877039e73ad8d0273254214cab4cc276aa9a67 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 12:20:17 -0700 Subject: [PATCH 044/162] Update tests --- builtin/logical/database/dbs/mysql_test.go | 14 ++++++++------ builtin/logical/database/dbs/plugin_test.go | 9 +++++---- builtin/logical/database/dbs/postgresql_test.go | 14 ++++++++------ 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index f4d124702..553acc8ff 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + log "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v2" ) @@ -70,21 +71,22 @@ func TestMySQL_Initialize(t *testing.T) { }, } - dbRaw, err := BuiltinFactory(conf, nil) + dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } // Deconsturct the middleware chain to get the underlying mysql object - dbMetrics := dbRaw.(*databaseMetricsMiddleware) + dbTracer := dbRaw.(*databaseTracingMiddleware) + dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) db := dbMetrics.next.(*MySQL) + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) err = dbRaw.Initialize(conf.ConnectionDetails) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) if !connProducer.initalized { t.Fatal("Database should be initalized") } @@ -112,7 +114,7 @@ func TestMySQL_CreateUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -200,7 +202,7 @@ func TestMySQL_RenewUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -256,7 +258,7 @@ func TestMySQL_RevokeUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/dbs/plugin_test.go index 151d0c88f..60cb6814d 100644 --- a/builtin/logical/database/dbs/plugin_test.go +++ b/builtin/logical/database/dbs/plugin_test.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" + log "github.com/mgutz/logxi/v1" ) type mockPlugin struct { @@ -134,7 +135,7 @@ func TestPlugin_Initialize(t *testing.T) { defer ln.Close() conf := getConf(t) - dbRaw, err := PluginFactory(conf, sys) + dbRaw, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -155,7 +156,7 @@ func TestPlugin_CreateUser(t *testing.T) { defer ln.Close() conf := getConf(t) - db, err := PluginFactory(conf, sys) + db, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -209,7 +210,7 @@ func TestPlugin_RenewUser(t *testing.T) { defer ln.Close() conf := getConf(t) - db, err := PluginFactory(conf, sys) + db, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -255,7 +256,7 @@ func TestPlugin_RevokeUser(t *testing.T) { defer ln.Close() conf := getConf(t) - db, err := PluginFactory(conf, sys) + db, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go index dab720920..83aed50ba 100644 --- a/builtin/logical/database/dbs/postgresql_test.go +++ b/builtin/logical/database/dbs/postgresql_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + log "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v2" ) @@ -77,21 +78,22 @@ func TestPostgreSQL_Initialize(t *testing.T) { }, } - dbRaw, err := BuiltinFactory(conf, nil) + dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } // Deconsturct the middleware chain to get the underlying postgres object - dbMetrics := dbRaw.(*databaseMetricsMiddleware) + dbTracer := dbRaw.(*databaseTracingMiddleware) + dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) db := dbMetrics.next.(*PostgreSQL) + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) err = dbRaw.Initialize(conf.ConnectionDetails) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) if !connProducer.initalized { t.Fatal("Database should be initalized") } @@ -119,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -212,7 +214,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -268,7 +270,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } From b09526e1c9a338d5c356a1c401323dbdcd31a786 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 12:57:30 -0700 Subject: [PATCH 045/162] Cleanup the db factory code and add comments --- builtin/logical/database/dbs/db.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 74f5a2605..2637a73d1 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,11 +20,15 @@ const ( var ( ErrUnsupportedDatabaseType = errors.New("unsupported database type") ErrEmptyCreationStatement = errors.New("empty creation statements") + ErrEmptyPluginCommand = errors.New("empty plugin command") + ErrEmptyPluginChecksum = errors.New("empty plugin checksum") ) -// Factory function for +// Factory function definition type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) +// BuiltinFactory is used to build builtin database types. It wraps the database +// object in a logging and metrics middleware. func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { var dbType DatabaseType @@ -88,15 +92,20 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log return dbType, nil } +// PluginFactory is used to build plugin database types. It wraps the database +// object in a logging and metrics middleware. func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { if conf.PluginCommand == "" { - return nil, errors.New("ERROR") + return nil, ErrEmptyPluginCommand } if conf.PluginChecksum == "" { - return nil, errors.New("ERROR") + return nil, ErrEmptyPluginChecksum } + // Make sure the database type is set to plugin + conf.DatabaseType = pluginTypeName + db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum) if err != nil { return nil, err @@ -118,6 +127,7 @@ func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logg return db, nil } +// DatabaseType is the interface that all database objects must implement. type DatabaseType interface { Type() string CreateUser(statements Statements, username, password, expiration string) error @@ -129,8 +139,12 @@ type DatabaseType interface { CredentialsProducer } +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. type DatabaseConfig struct { - DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` @@ -139,6 +153,8 @@ type DatabaseConfig struct { PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"` } +// GetFactory returns the appropriate factory method for the given database +// type. func (dc *DatabaseConfig) GetFactory() Factory { if dc.DatabaseType == pluginTypeName { return PluginFactory From 50729a4528a055013e344d286ce5f738fcef066f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 13:08:11 -0700 Subject: [PATCH 046/162] Add comments to connection and credential producers --- builtin/logical/database/dbs/connectionproducer.go | 7 ++++++- builtin/logical/database/dbs/credentialsproducer.go | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index dae8d9400..ca9e7250e 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -23,6 +23,9 @@ var ( errNotInitalized = errors.New("connection has not been initalized") ) +// ConnectionProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods dealing with individual database +// connections and is used in all the builtin database types. type ConnectionProducer interface { Close() error Initialize(map[string]interface{}) error @@ -31,7 +34,7 @@ type ConnectionProducer interface { connection() (interface{}, error) } -// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases +// sqlConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases type sqlConnectionProducer struct { ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` @@ -111,6 +114,8 @@ func (c *sqlConnectionProducer) Close() error { return nil } +// cassandraConnectionProducer implements ConnectionProducer and provides an +// interface for cassandra databases to make connections. type cassandraConnectionProducer struct { Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` Username string `json:"username" structs:"username" mapstructure:"username"` diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go index 5ae3b128e..6bd543f4e 100644 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -8,20 +8,22 @@ import ( uuid "github.com/hashicorp/go-uuid" ) +// CredentialsProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods for generating user information for a +// particular database type and is used in all the builtin database types. type CredentialsProducer interface { GenerateUsername(displayName string) (string, error) GeneratePassword() (string, error) GenerateExpiration(ttl time.Duration) (string, error) } -// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types. +// sqlCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. type sqlCredentialsProducer struct { displayNameLen int usernameLen int } func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { - // Generate the username, password and expiration. PG limits user to 63 characters if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { displayName = displayName[:scp.displayNameLen] } @@ -52,6 +54,8 @@ func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string Format("2006-01-02 15:04:05-0700"), nil } +// cassandraCredentialsProducer implements CredentialsProducer and provides an +// interface for cassandra databases to generate user information. type cassandraCredentialsProducer struct{} func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { From 210fa77e3c356f028c46162dbca656e261aa3116 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 14:37:57 -0700 Subject: [PATCH 047/162] fix for plugin commands that have more than one paramater --- builtin/logical/database/dbs/plugin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index b1f9abe20..4bac0d16e 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -84,7 +84,7 @@ func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (Database commandArr := strings.Split(command, " ") var cmd *exec.Cmd if len(commandArr) > 1 { - cmd = exec.Command(commandArr[0], commandArr[1]) + cmd = exec.Command(commandArr[0], commandArr[1:]...) } else { cmd = exec.Command(commandArr[0]) } From aa15a1d3a9a62e926be7c8fd74b2e13bccdd2a06 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 3 Apr 2017 12:59:30 -0400 Subject: [PATCH 048/162] Database refactor mssql (#2562) * WIP on mssql secret backend refactor * Add RevokeUser test, and use sqlserver driver internally * Remove debug statements * Fix code comment --- .../database/dbs/connectionproducer.go | 9 +- builtin/logical/database/dbs/db.go | 17 +- builtin/logical/database/dbs/mssql.go | 219 +++++++++++++++++ builtin/logical/database/dbs/mssql_test.go | 221 ++++++++++++++++++ 4 files changed, 464 insertions(+), 2 deletions(-) create mode 100644 builtin/logical/database/dbs/mssql.go create mode 100644 builtin/logical/database/dbs/mssql_test.go diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index ca9e7250e..b5dc93951 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -10,6 +10,7 @@ import ( "time" // Import sql drivers + _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "github.com/mitchellh/mapstructure" @@ -73,6 +74,12 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { c.db.Close() } + // For mssql backend, switch to sqlserver instead + dbType := c.config.DatabaseType + if c.config.DatabaseType == "mssql" { + dbType = "sqlserver" + } + // Otherwise, attempt to make connection conn := c.ConnectionURL @@ -86,7 +93,7 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { } var err error - c.db, err = sql.Open(c.config.DatabaseType, conn) + c.db, err = sql.Open(dbType, conn) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2637a73d1..cf8f8ee7f 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -13,6 +13,7 @@ import ( const ( postgreSQLTypeName = "postgres" mySQLTypeName = "mysql" + msSQLTypeName = "mssql" cassandraTypeName = "cassandra" pluginTypeName = "plugin" ) @@ -61,6 +62,20 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log CredentialsProducer: credsProducer, } + case msSQLTypeName: + connProducer := &sqlConnectionProducer{} + connProducer.config = conf + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 10, + usernameLen: 63, + } + + dbType = &MSSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + case cassandraTypeName: connProducer := &cassandraConnectionProducer{} connProducer.config = conf @@ -163,7 +178,7 @@ func (dc *DatabaseConfig) GetFactory() Factory { return BuiltinFactory } -// Statments set in role creation and passed into the database type's functions. +// Statements set in role creation and passed into the database type's functions. // TODO: Add a way of setting defaults here. type Statements struct { CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` diff --git a/builtin/logical/database/dbs/mssql.go b/builtin/logical/database/dbs/mssql.go new file mode 100644 index 000000000..b7439b0a8 --- /dev/null +++ b/builtin/logical/database/dbs/mssql.go @@ -0,0 +1,219 @@ +package dbs + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/hashicorp/vault/helper/strutil" +) + +// MSSQL is an implementation of DatabaseType interface +type MSSQL struct { + ConnectionProducer + CredentialsProducer +} + +// Type returns the TypeName for this backend +func (m *MSSQL) Type() string { + return msSQLTypeName +} + +func (m *MSSQL) getConnection() (*sql.DB, error) { + db, err := m.connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by +// the CreationStatement provided. +func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return err + } + + if statements.CreationStatements == "" { + return ErrEmptyCreationStatement + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// RenewUser is not supported on MSSQL, so this is a no-op. +func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error { + // NOOP + return nil +} + +// RevokeUser attempts to drop the specified user. It will first attempt to disable login, +// then kill pending connections from that user, and finally drop the user and login from the +// database instance. +func (m *MSSQL) RevokeUser(statements Statements, username string) error { + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // First disable server login + disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) + if err != nil { + return err + } + defer disableStmt.Close() + if _, err := disableStmt.Exec(); err != nil { + return err + } + + // Query for sessions for the login so that we can kill any outstanding + // sessions. There cannot be any active sessions before we drop the logins + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + sessionStmt, err := db.Prepare(fmt.Sprintf( + "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) + if err != nil { + return err + } + defer sessionStmt.Close() + + sessionRows, err := sessionStmt.Query() + if err != nil { + return err + } + defer sessionRows.Close() + + var revokeStmts []string + for sessionRows.Next() { + var sessionID int + err = sessionRows.Scan(&sessionID) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) + } + + // Query for database users using undocumented stored procedure for now since + // it is the easiest way to get this information; + // we need to drop the database users before we can drop the login and the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var loginName, dbName, qUsername string + var aliasName sql.NullString + err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) + } + + // we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revokeStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all database users are dropped + if rows.Err() != nil { + return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) + } + + // Drop this login + stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +const dropUserSQL = ` +USE [%s] +IF EXISTS + (SELECT name + FROM sys.database_principals + WHERE name = N'%s') +BEGIN + DROP USER [%s] +END +` + +const dropLoginSQL = ` +IF EXISTS + (SELECT name + FROM master.sys.server_principals + WHERE name = N'%s') +BEGIN + DROP LOGIN [%s] +END +` diff --git a/builtin/logical/database/dbs/mssql_test.go b/builtin/logical/database/dbs/mssql_test.go new file mode 100644 index 000000000..f2169299f --- /dev/null +++ b/builtin/logical/database/dbs/mssql_test.go @@ -0,0 +1,221 @@ +package dbs + +import ( + "database/sql" + "fmt" + "os" + "sync" + "testing" + "time" + + _ "github.com/denisenkom/go-mssqldb" + log "github.com/mgutz/logxi/v1" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMSQLImagePull sync.Once +) + +func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MSSQL_URL") != "" { + return func() {}, os.Getenv("MSSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) + if err != nil { + t.Fatalf("Could not start local MSSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local DynamoDB: %s", err) + } + } + + retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) + + // exponential backoff-retry, because the mssql container may not be able to accept connections yet + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mssql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MSSQL docker container: %s", err) + } + + return +} + +func TestMSSQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Deconsturct the middleware chain to get the underlying mssql object + dbTracer := dbRaw.(*databaseTracingMiddleware) + dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) + db := dbMetrics.next.(*MSSQL) + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.initalized { + t.Fatal("Database should be initalized") + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + if connProducer.db != nil { + t.Fatal("db object should be nil") + } +} + +func TestMSSQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := Statements{ + CreationStatements: testMSSQLRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMSSQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testMSSQLRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +const testMSSQLRole = ` +CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; +CREATE USER [{{name}}] FOR LOGIN [{{name}}]; +GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` From e8781b6a2be18c86fe06292cb175e84b02b01a33 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 3 Apr 2017 17:52:29 -0700 Subject: [PATCH 049/162] Plugin catalog --- builtin/logical/database/dbs/db.go | 17 ++- builtin/logical/database/dbs/plugin.go | 50 +-------- .../database/path_config_connection.go | 11 +- command/server.go | 13 +++ command/server/config.go | 4 +- helper/pluginutil/runner.go | 61 +++++++++++ logical/system_view.go | 7 ++ vault/core.go | 20 +++- vault/dynamic_system_view.go | 5 + vault/logical_system.go | 89 +++++++++++++++ vault/plugin_catalog.go | 101 ++++++++++++++++++ 11 files changed, 310 insertions(+), 68 deletions(-) create mode 100644 helper/pluginutil/runner.go create mode 100644 vault/plugin_catalog.go diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2637a73d1..8d44a474e 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,8 +20,7 @@ const ( var ( ErrUnsupportedDatabaseType = errors.New("unsupported database type") ErrEmptyCreationStatement = errors.New("empty creation statements") - ErrEmptyPluginCommand = errors.New("empty plugin command") - ErrEmptyPluginChecksum = errors.New("empty plugin checksum") + ErrEmptyPluginName = errors.New("empty plugin name") ) // Factory function definition @@ -95,18 +94,19 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginCommand == "" { - return nil, ErrEmptyPluginCommand + if conf.PluginName == "" { + return nil, ErrEmptyPluginName } - if conf.PluginChecksum == "" { - return nil, ErrEmptyPluginChecksum + pluginMeta, err := sys.LookupPlugin(conf.PluginName) + if err != nil { + return nil, err } // Make sure the database type is set to plugin conf.DatabaseType = pluginTypeName - db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum) + db, err := newPluginClient(sys, pluginMeta) if err != nil { return nil, err } @@ -149,8 +149,7 @@ type DatabaseConfig struct { MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` - PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` - PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"` + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` } // GetFactory returns the appropriate factory method for the given database diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 4bac0d16e..791f3b465 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,12 +1,8 @@ package dbs import ( - "crypto/sha256" - "encoding/hex" "fmt" "net/rpc" - "os/exec" - "strings" "sync" "time" @@ -55,59 +51,17 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (DatabaseType, error) { +func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), } - // Get a CA TLS Certificate - CACertBytes, CACert, CAKey, err := pluginutil.GenerateCACert() + client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}) if err != nil { return nil, err } - // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := pluginutil.CreateClientTLSConfig(CACert, CAKey) - if err != nil { - return nil, err - } - - // Use CA to sign a server cert and wrap the values in a response wrapped - // token. - wrapToken, err := pluginutil.WrapServerConfig(sys, CACertBytes, CACert, CAKey) - if err != nil { - return nil, err - } - - // Add the response wrap token to the ENV of the plugin - commandArr := strings.Split(command, " ") - var cmd *exec.Cmd - if len(commandArr) > 1 { - cmd = exec.Command(commandArr[0], commandArr[1:]...) - } else { - cmd = exec.Command(commandArr[0]) - } - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) - - checksumDecoded, err := hex.DecodeString(checksum) - if err != nil { - return nil, err - } - - secureConfig := &plugin.SecureConfig{ - Checksum: checksumDecoded, - Hash: sha256.New(), - } - - client := plugin.NewClient(&plugin.ClientConfig{ - HandshakeConfig: handshakeConfig, - Plugins: pluginMap, - Cmd: cmd, - TLSConfig: clientTLSConfig, - SecureConfig: secureConfig, - }) - // Connect via RPC rpcClient, err := client.Client() if err != nil { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index b4c699750..a0494d71e 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -112,13 +112,7 @@ reduced to the same size.`, a zero or negative value reuses connections forever.`, }, - "plugin_command": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, - }, - - "plugin_checksum": &framework.FieldSchema{ + "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, @@ -223,8 +217,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, - PluginCommand: data.Get("plugin_command").(string), - PluginChecksum: data.Get("plugin_checksum").(string), + PluginName: data.Get("plugin_name").(string), } name := data.Get("name").(string) diff --git a/command/server.go b/command/server.go index 09658b949..d6eb0d76d 100644 --- a/command/server.go +++ b/command/server.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "os/signal" + "path/filepath" "runtime" "sort" "strconv" @@ -20,6 +21,7 @@ import ( colorable "github.com/mattn/go-colorable" log "github.com/mgutz/logxi/v1" + homedir "github.com/mitchellh/go-homedir" "google.golang.org/grpc/grpclog" @@ -237,11 +239,22 @@ func (c *ServerCommand) Run(args []string) int { DefaultLeaseTTL: config.DefaultLeaseTTL, ClusterName: config.ClusterName, CacheSize: config.CacheSize, + PluginDirectory: config.PluginDirectory, } if dev { coreConfig.DevToken = devRootTokenID } + if config.PluginDirectory == "" { + homePath, err := homedir.Dir() + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error getting user's home directory: %v", err)) + return 1 + } + coreConfig.PluginDirectory = filepath.Join(homePath, "/vault-plugins/") + } + var disableClustering bool // Initialize the separate HA physical backend, if it exists diff --git a/command/server/config.go b/command/server/config.go index 00edd5de9..a57fdad13 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -38,7 +38,8 @@ type Config struct { DefaultLeaseTTL time.Duration `hcl:"-"` DefaultLeaseTTLRaw string `hcl:"default_lease_ttl"` - ClusterName string `hcl:"cluster_name"` + ClusterName string `hcl:"cluster_name"` + PluginDirectory string `hcl:"plugin_directory"` } // DevConfig is a Config that is used for dev mode of Vault. @@ -339,6 +340,7 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) { "default_lease_ttl", "max_lease_ttl", "cluster_name", + "plugin_directory", } if err := checkHCLKeys(list, valid); err != nil { return nil, err diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go new file mode 100644 index 000000000..143a4c839 --- /dev/null +++ b/helper/pluginutil/runner.go @@ -0,0 +1,61 @@ +package pluginutil + +import ( + "crypto/sha256" + "fmt" + "os/exec" + + plugin "github.com/hashicorp/go-plugin" +) + +type Looker interface { + LookupPlugin(string) (*PluginRunner, error) +} + +type PluginRunner struct { + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args"` + Sha256 []byte `json:"sha256"` +} + +func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { + // Get a CA TLS Certificate + CACertBytes, CACert, CAKey, err := GenerateCACert() + if err != nil { + return nil, err + } + + // Use CA to sign a client cert and return a configured TLS config + clientTLSConfig, err := CreateClientTLSConfig(CACert, CAKey) + if err != nil { + return nil, err + } + + // Use CA to sign a server cert and wrap the values in a response wrapped + // token. + wrapToken, err := WrapServerConfig(wrapper, CACertBytes, CACert, CAKey) + if err != nil { + return nil, err + } + + // Add the response wrap token to the ENV of the plugin + cmd := exec.Command(r.Command, r.Args...) + cmd.Env = append(cmd.Env, env...) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + + secureConfig := &plugin.SecureConfig{ + Checksum: r.Sha256, + Hash: sha256.New(), + } + + client := plugin.NewClient(&plugin.ClientConfig{ + HandshakeConfig: hs, + Plugins: pluginMap, + Cmd: cmd, + TLSConfig: clientTLSConfig, + SecureConfig: secureConfig, + }) + + return client, nil +} diff --git a/logical/system_view.go b/logical/system_view.go index 56254b33a..a9626bc50 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -5,6 +5,7 @@ import ( "time" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/pluginutil" ) // SystemView exposes system configuration information in a safe way @@ -42,6 +43,8 @@ type SystemView interface { // ResponseWrapData wraps the given data in a cubbyhole and returns the // token used to unwrap. ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + + LookupPlugin(string) (*pluginutil.PluginRunner, error) } type StaticSystemView struct { @@ -81,3 +84,7 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { return "", errors.New("ResponseWrapData is not implimented in StaticSystemView") } + +func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { + return nil, errors.New("LookupPlugin is not implimented in StaticSystemView") +} diff --git a/vault/core.go b/vault/core.go index ea378fa8a..08a828643 100644 --- a/vault/core.go +++ b/vault/core.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/url" + "path/filepath" "sync" "time" @@ -330,6 +331,12 @@ type Core struct { // uiEnabled indicates whether Vault Web UI is enabled or not uiEnabled bool + + // pluginDirectory is the location vault will look for plugins + pluginDirectory string + + // pluginCatalog is used to manage plugin configurations + pluginCatalog *PluginCatalog } // CoreConfig is used to parameterize a core @@ -374,6 +381,8 @@ type CoreConfig struct { EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"` + PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` + ReloadFuncs *map[string][]ReloadFunc ReloadFuncsLock *sync.RWMutex } @@ -453,8 +462,13 @@ func NewCore(conf *CoreConfig) (*Core, error) { } } - // Construct a new AES-GCM barrier var err error + c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) + if err != nil { + return nil, fmt.Errorf("core setup failed: %v", err) + } + + // Construct a new AES-GCM barrier c.barrier, err = NewAESGCMBarrier(c.physical) if err != nil { return nil, fmt.Errorf("barrier setup failed: %v", err) @@ -1280,6 +1294,10 @@ func (c *Core) postUnseal() (retErr error) { if err := c.setupAuditedHeadersConfig(); err != nil { return err } + if err := c.setupPluginCatalog(); err != nil { + return err + } + if c.ha != nil { if err := c.startClusterListener(); err != nil { return err diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 4c6807ace..f318f3ab1 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -4,6 +4,7 @@ import ( "time" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" ) @@ -114,3 +115,7 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim return resp.WrapInfo.Token, nil } + +func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { + return d.core.pluginCatalog.Get(name) +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 1c439506c..f5dbe2aff 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -63,6 +63,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/reindex", "rotate", "config/auditing/*", + "plugin-catalog", + "plugin-catalog/*", }, Unauthenticated: []string{ @@ -692,6 +694,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]), HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]), }, + &framework.Path{ + Pattern: "plugin-catalog/(?P.+)", + + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + }, + "sha_256": &framework.FieldSchema{ + Type: framework.TypeString, + }, + "command": &framework.FieldSchema{ + Type: framework.TypeString, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.handlePluginCatalogUpdate, + logical.DeleteOperation: b.handlePluginCatalogDelete, + logical.ReadOperation: b.handlePluginCatalogRead, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), + HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + }, }, } @@ -724,6 +750,69 @@ func (b *SystemBackend) invalidate(key string) { } } +func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + + sha256 := d.Get("sha_256").(string) + if sha256 == "" { + return logical.ErrorResponse("missing SHA-256 value"), nil + } + + command := d.Get("command").(string) + if command == "" { + return logical.ErrorResponse("missing command value"), nil + } + + sha256Bytes, err := hex.DecodeString(sha256) + if err != nil { + return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err + } + + err = b.Core.pluginCatalog.Set(pluginName, command, sha256Bytes) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *SystemBackend) handlePluginCatalogRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + plugin, err := b.Core.pluginCatalog.Get(pluginName) + if err != nil { + return nil, err + } + + return &logical.Response{ + Data: map[string]interface{}{ + "plugin": plugin, + }, + }, nil +} + +func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + plugin, err := b.Core.pluginCatalog.Get(pluginName) + if err != nil { + return nil, err + } + + return &logical.Response{ + Data: map[string]interface{}{ + "plugin": plugin, + }, + }, nil +} + // handleAuditedHeaderUpdate creates or overwrites a header entry func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { header := d.Get("header").(string) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go new file mode 100644 index 000000000..c1f504d2c --- /dev/null +++ b/vault/plugin_catalog.go @@ -0,0 +1,101 @@ +package vault + +import ( + "encoding/json" + "errors" + "fmt" + "path/filepath" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/logical" +) + +var ( + pluginCatalogPrefix = "plugin-catalog/" +) + +type PluginCatalog struct { + catalogView *BarrierView + directory string + + lock sync.RWMutex + builtin map[string]*pluginutil.PluginRunner +} + +func NewPluginCatalog(view *BarrierView, directory string) *PluginCatalog { + return &PluginCatalog{ + catalogView: view.SubView(pluginCatalogPrefix), + directory: directory, + } +} + +func (c *Core) setupPluginCatalog() error { + catalog := NewPluginCatalog(c.systemBarrierView, c.pluginDirectory) + c.pluginCatalog = catalog + + return nil +} + +func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { + out, err := c.catalogView.Get(name) + if err != nil { + return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) + } + if out == nil { + return nil, fmt.Errorf("no plugin found with name: %s", name) + } + + entry := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + } + + return entry, nil +} + +func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { + parts := strings.Split(command, " ") + command = parts[0] + args := parts[1:] + + command = filepath.Join(c.directory, command) + + // Best effort check to make sure the command isn't breaking out of the + // configured plugin directory. + sym, err := filepath.EvalSymlinks(command) + if err != nil { + return fmt.Errorf("error while validating the command path: %v", err) + } + symAbs, err := filepath.Abs(filepath.Dir(sym)) + if err != nil { + return fmt.Errorf("error while validating the command path: %v", err) + } + + if symAbs != c.directory { + return errors.New("can not execute files outside of configured plugin directory") + } + + entry := &pluginutil.PluginRunner{ + Name: name, + Command: command, + Args: args, + Sha256: sha256, + } + + buf, err := json.Marshal(entry) + if err != nil { + return fmt.Errorf("failed to encode plugin entry: %v", err) + } + + logicalEntry := logical.StorageEntry{ + Key: name, + Value: buf, + } + if err := c.catalogView.Put(&logicalEntry); err != nil { + return fmt.Errorf("failed to persist plugin entry: %v", err) + } + return nil +} From b506bd7790abfccadbe9f18ab7beb6be0775efa9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 3 Apr 2017 18:30:38 -0700 Subject: [PATCH 050/162] On change of configuration rotate the database type --- .../database/path_config_connection.go | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index a0494d71e..a1d32d572 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -243,29 +243,23 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } if verifyConnection { - return logical.ErrorResponse(err.Error()), nil + return logical.ErrorResponse("Could not verify connection"), nil } } if _, ok := b.connections[name]; ok { - newType := db.Type() - - // Don't update connection until the reset api is hit, close for - // now. - err = db.Close() + // Close and remove the old connection + err := b.connections[name].Close() if err != nil { return nil, err } - // Don't allow the connection type to change - if b.connections[name].Type() != newType { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - } else { - // Save the new connection - b.connections[name] = db + delete(b.connections, name) } + // Save the new connection + b.connections[name] = db + // Store it entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) if err != nil { From 9dd666c7e6ad9b545694965a96735a9f9fec112b Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Tue, 4 Apr 2017 14:32:42 -0400 Subject: [PATCH 051/162] Database refactor invalidate (#2566) * WIP on invalidate function * cassandraConnectionProducer has Close() * Delete database from connections map on successful db.Close() * Move clear connection into its own func * Use const for database config path --- builtin/logical/database/backend.go | 31 +++++++++++++++++-- .../database/path_config_connection.go | 8 ++--- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index f8bcc60f1..4d069a432 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -12,6 +12,8 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +const databaseConfigPath = "database/dbs/" + func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -41,6 +43,8 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { }, Clean: b.closeAllDBs, + + Invalidate: b.invalidate, } b.logger = conf.Logger @@ -123,9 +127,32 @@ func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) return &result, nil } +func (b *databaseBackend) invalidate(key string) { + b.Lock() + defer b.Unlock() + + switch { + case strings.HasPrefix(key, databaseConfigPath): + name := strings.TrimPrefix(key, databaseConfigPath) + b.clearConnection(name) + } +} + +// clearConnection closes the database connection and +// removes it from the b.connections map. +func (b *databaseBackend) clearConnection(name string) { + db, ok := b.connections[name] + if ok { + db.Close() + delete(b.connections, name) + } +} + const backendHelp = ` -The PostgreSQL backend dynamically generates database users. +The database backend supports using many different databases +as secret backends, including but not limited to: +cassandra, msslq, mysql, postgres After mounting this backend, configure it using the endpoints within -the "config/" path. +the "database/dbs/" path. ` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index a1d32d572..be2038c31 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -40,13 +40,9 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew b.Lock() defer b.Unlock() - db, ok := b.connections[name] - if ok { - db.Close() - delete(b.connections, name) - } + b.clearConnection(name) - db, err := b.getOrCreateDBObj(req.Storage, name) + _, err := b.getOrCreateDBObj(req.Storage, name) if err != nil { return nil, err } From 305ccd54f7f8d92ac3c53f6bd91cfc4578288063 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 4 Apr 2017 11:33:58 -0700 Subject: [PATCH 052/162] Don't return strings, always structs --- builtin/logical/database/dbs/plugin.go | 44 ++++++++++++++++---------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 791f3b465..441f97ca0 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -169,24 +169,24 @@ func (dr *databasePluginRPCClient) Close() error { } func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { - var username string - err := dr.client.Call("Plugin.GenerateUsername", displayName, &username) + resp := &GenerateUsernameResponse{} + err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) - return username, err + return resp.Username, err } func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { - var password string - err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, &password) + resp := &GeneratePasswordResponse{} + err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) - return password, err + return resp.Password, err } func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { - var expiration string - err := dr.client.Call("Plugin.GenerateExpiration", duration, &expiration) + resp := &GenerateExpirationResponse{} + err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) - return expiration, err + return resp.Expiration, err } // ---- RPC server domain ---- @@ -230,28 +230,28 @@ func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { return nil } -func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *string) error { +func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { var err error - *resp, err = ds.impl.GenerateUsername(args) + resp.Username, err = ds.impl.GenerateUsername(args) return err } -func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *string) error { +func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { var err error - *resp, err = ds.impl.GeneratePassword() + resp.Password, err = ds.impl.GeneratePassword() return err } -func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *string) error { +func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { var err error - *resp, err = ds.impl.GenerateExpiration(args) + resp.Expiration, err = ds.impl.GenerateExpiration(args) return err } -// ---- Request Args domain ---- +// ---- Request Args Domain ---- type CreateUserRequest struct { Statements Statements @@ -270,3 +270,15 @@ type RevokeUserRequest struct { Statements Statements Username string } + +// ---- Response Args Domain ---- + +type GenerateUsernameResponse struct { + Username string +} +type GenerateExpirationResponse struct { + Expiration string +} +type GeneratePasswordResponse struct { + Password string +} From 2255884a4c6199c33e983ddfd9ba41ec971f4577 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Tue, 4 Apr 2017 17:26:59 -0400 Subject: [PATCH 053/162] Do not mark conn as initialized until the end (#2567) --- builtin/logical/database/dbs/connectionproducer.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index b5dc93951..31ef2853b 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -54,12 +54,13 @@ func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { if err != nil { return err } - c.initalized = true if _, err := c.connection(); err != nil { - return fmt.Errorf("Error Initalizing Connection: %s", err) + return fmt.Errorf("error initalizing connection: %s", err) } + c.initalized = true + return nil } From 0034074691e1d357fa35319abbf92dd287648a75 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 4 Apr 2017 14:43:39 -0700 Subject: [PATCH 054/162] Execute builtin plugins --- command/server.go | 63 ++++++++++++++++++++++++++++++----------- vault/core.go | 12 +++++++- vault/plugin_catalog.go | 48 +++++++++++++++++++------------ 3 files changed, 88 insertions(+), 35 deletions(-) diff --git a/command/server.go b/command/server.go index d6eb0d76d..3b1c771bb 100644 --- a/command/server.go +++ b/command/server.go @@ -1,8 +1,10 @@ package command import ( + "crypto/sha256" "encoding/base64" "fmt" + "io" "net" "net/http" "net/url" @@ -131,6 +133,33 @@ func (c *ServerCommand) Run(args []string) int { dev = true } + // Record the vault binary's location and SHA-256 checksum for use in + // builtin plugins. + ex, err := os.Executable() + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error looking up vault binary: %s", err)) + return 1 + } + + file, err := os.Open(ex) + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error loading vault binary: %s", err)) + return 1 + } + defer file.Close() + + hash := sha256.New() + _, err = io.Copy(hash, file) + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error checksumming vault binary: %s", err)) + return 1 + } + + sha256Value := hash.Sum(nil) + // Validation if !dev { switch { @@ -225,21 +254,23 @@ func (c *ServerCommand) Run(args []string) int { } coreConfig := &vault.CoreConfig{ - Physical: backend, - RedirectAddr: config.Backend.RedirectAddr, - HAPhysical: nil, - Seal: seal, - AuditBackends: c.AuditBackends, - CredentialBackends: c.CredentialBackends, - LogicalBackends: c.LogicalBackends, - Logger: c.logger, - DisableCache: config.DisableCache, - DisableMlock: config.DisableMlock, - MaxLeaseTTL: config.MaxLeaseTTL, - DefaultLeaseTTL: config.DefaultLeaseTTL, - ClusterName: config.ClusterName, - CacheSize: config.CacheSize, - PluginDirectory: config.PluginDirectory, + Physical: backend, + RedirectAddr: config.Backend.RedirectAddr, + HAPhysical: nil, + Seal: seal, + AuditBackends: c.AuditBackends, + CredentialBackends: c.CredentialBackends, + LogicalBackends: c.LogicalBackends, + Logger: c.logger, + DisableCache: config.DisableCache, + DisableMlock: config.DisableMlock, + MaxLeaseTTL: config.MaxLeaseTTL, + DefaultLeaseTTL: config.DefaultLeaseTTL, + ClusterName: config.ClusterName, + CacheSize: config.CacheSize, + PluginDirectory: config.PluginDirectory, + VaultBinaryLocation: ex, + VaultBinarySHA256: sha256Value, } if dev { coreConfig.DevToken = devRootTokenID @@ -252,7 +283,7 @@ func (c *ServerCommand) Run(args []string) int { "Error getting user's home directory: %v", err)) return 1 } - coreConfig.PluginDirectory = filepath.Join(homePath, "/vault-plugins/") + coreConfig.PluginDirectory = filepath.Join(homePath, "/.vault-plugins/") } var disableClustering bool diff --git a/vault/core.go b/vault/core.go index 08a828643..ffd36683b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -335,6 +335,12 @@ type Core struct { // pluginDirectory is the location vault will look for plugins pluginDirectory string + // vaultBinaryLocation is used to run builtin plugins in secure mode + vaultBinaryLocation string + + // vaultBinarySHA256 is used to run builtin plugins in secure mode + vaultBinarySHA256 []byte + // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog } @@ -381,7 +387,9 @@ type CoreConfig struct { EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"` - PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` + PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` + VaultBinaryLocation string `json:"vault_binary_location" structs:"vault_binary_location" mapstructure:"vault_binary_location"` + VaultBinarySHA256 []byte `json:"vault_binary_sha256" structs:"vault_binary_sha256" mapstructure:"vault_binary_sha256"` ReloadFuncs *map[string][]ReloadFunc ReloadFuncsLock *sync.RWMutex @@ -439,6 +447,8 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterName: conf.ClusterName, clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), + vaultBinaryLocation: conf.VaultBinaryLocation, + vaultBinarySHA256: conf.VaultBinarySHA256, } // Wrap the physical backend in a cache layer if enabled and not already wrapped diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index c1f504d2c..88265a245 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -10,50 +10,62 @@ import ( "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" ) var ( pluginCatalogPrefix = "plugin-catalog/" + builtinPlugins = []string{"mysql-database-plugin", "postgres-database-plugin"} ) type PluginCatalog struct { - catalogView *BarrierView - directory string + catalogView *BarrierView + directory string + vaultCommand string + vaultSHA256 []byte lock sync.RWMutex builtin map[string]*pluginutil.PluginRunner } -func NewPluginCatalog(view *BarrierView, directory string) *PluginCatalog { - return &PluginCatalog{ - catalogView: view.SubView(pluginCatalogPrefix), - directory: directory, - } -} - func (c *Core) setupPluginCatalog() error { - catalog := NewPluginCatalog(c.systemBarrierView, c.pluginDirectory) - c.pluginCatalog = catalog + c.pluginCatalog = &PluginCatalog{ + catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), + directory: c.pluginDirectory, + vaultCommand: c.vaultBinaryLocation, + vaultSHA256: c.vaultBinarySHA256, + } return nil } func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { + // Look for external plugins in the barrier out, err := c.catalogView.Get(name) if err != nil { return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) } - if out == nil { + if out != nil { + entry := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + } + + return entry, nil + } + + // Look for builtin plugins + if !strutil.StrListContains(builtinPlugins, name) { return nil, fmt.Errorf("no plugin found with name: %s", name) } - entry := new(pluginutil.PluginRunner) - if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { - return nil, fmt.Errorf("failed to decode plugin entry: %v", err) - } - - return entry, nil + return &pluginutil.PluginRunner{ + Name: name, + Command: c.vaultCommand, + Args: []string{"plugin-exec", name}, + Sha256: c.vaultSHA256, + }, nil } func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { From 11abcd52e696c5163565b8bf4764af0c4ac71b87 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 4 Apr 2017 17:12:02 -0700 Subject: [PATCH 055/162] Add a cli command to run builtin plugins --- cli/commands.go | 6 ++++ command/plugin-exec.go | 71 +++++++++++++++++++++++++++++++++++++++++ vault/plugin_catalog.go | 3 +- 3 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 command/plugin-exec.go diff --git a/cli/commands.go b/cli/commands.go index 13f7c8b25..e7545ca90 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -331,5 +331,11 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { Ui: metaPtr.Ui, }, nil }, + + "plugin-exec": func() (cli.Command, error) { + return &command.PluginExec{ + Meta: *metaPtr, + }, nil + }, } } diff --git a/command/plugin-exec.go b/command/plugin-exec.go new file mode 100644 index 000000000..18dc3e145 --- /dev/null +++ b/command/plugin-exec.go @@ -0,0 +1,71 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/hashicorp/vault/meta" +) + +type PluginExec struct { + meta.Meta +} + +var builtinFactories = map[string]func() error{ +// "mysql-database-plugin": mysql.Factory, +// "postgres-database-plugin": postgres.Factory, +} + +func (c *PluginExec) Run(args []string) int { + flags := c.Meta.FlagSet("plugin-exec", meta.FlagSetDefault) + flags.Usage = func() { c.Ui.Error(c.Help()) } + if err := flags.Parse(args); err != nil { + return 1 + } + + args = flags.Args() + if len(args) != 1 { + flags.Usage() + c.Ui.Error(fmt.Sprintf( + "\nplugin-exec expects one argument: the plugin to execute.")) + return 1 + } + + pluginName := args[0] + + factory, ok := builtinFactories[pluginName] + if !ok { + c.Ui.Error(fmt.Sprintf( + "No plugin with the name %s found", pluginName)) + return 1 + } + + err := factory() + if err != nil { + c.Ui.Error(fmt.Sprintf( + "Error running plugin: %s", err)) + return 1 + } + + return 0 +} + +func (c *PluginExec) Synopsis() string { + return "Force the Vault node to give up active duty" +} + +func (c *PluginExec) Help() string { + helpText := ` +Usage: vault step-down [options] + + Force the Vault node to step down from active duty. + + This causes the indicated node to give up active status. Note that while the + affected node will have a short delay before attempting to grab the lock + again, if no other node grabs the lock beforehand, it is possible for the + same node to re-grab the lock and become active again. + +General Options: +` + meta.GeneralOptionsUsage() + return strings.TrimSpace(helpText) +} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 88265a245..eccac2bd1 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -25,8 +25,7 @@ type PluginCatalog struct { vaultCommand string vaultSHA256 []byte - lock sync.RWMutex - builtin map[string]*pluginutil.PluginRunner + lock sync.RWMutex } func (c *Core) setupPluginCatalog() error { From b071144c67341c19072c4b93db5a491dace4566d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 5 Apr 2017 11:00:13 -0700 Subject: [PATCH 056/162] move builtin plugins list to the pluginutil --- command/plugin-exec.go | 23 +++++++++-------------- helper/pluginutil/builtin.go | 6 ++++++ vault/plugin_catalog.go | 4 +--- 3 files changed, 16 insertions(+), 17 deletions(-) create mode 100644 helper/pluginutil/builtin.go diff --git a/command/plugin-exec.go b/command/plugin-exec.go index 18dc3e145..f0d6a8d51 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/meta" ) @@ -11,11 +12,6 @@ type PluginExec struct { meta.Meta } -var builtinFactories = map[string]func() error{ -// "mysql-database-plugin": mysql.Factory, -// "postgres-database-plugin": postgres.Factory, -} - func (c *PluginExec) Run(args []string) int { flags := c.Meta.FlagSet("plugin-exec", meta.FlagSetDefault) flags.Usage = func() { c.Ui.Error(c.Help()) } @@ -33,14 +29,14 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - factory, ok := builtinFactories[pluginName] + runner, ok := pluginutil.BuiltinPlugins[pluginName] if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) return 1 } - err := factory() + err := runner() if err != nil { c.Ui.Error(fmt.Sprintf( "Error running plugin: %s", err)) @@ -51,19 +47,18 @@ func (c *PluginExec) Run(args []string) int { } func (c *PluginExec) Synopsis() string { - return "Force the Vault node to give up active duty" + return "Runs a builtin plugin. Should only be called by vault." } func (c *PluginExec) Help() string { helpText := ` -Usage: vault step-down [options] +Usage: vault plugin-exec type - Force the Vault node to step down from active duty. + Runs a builtin plugin. Should only be called by vault. - This causes the indicated node to give up active status. Note that while the - affected node will have a short delay before attempting to grab the lock - again, if no other node grabs the lock beforehand, it is possible for the - same node to re-grab the lock and become active again. + This will execute a plugin for use in a plugable location in vault. If run by + a cli user it will print a message indicating it can not be executed by anyone + other than vault. For supported plugin types see the vault documentation. General Options: ` + meta.GeneralOptionsUsage() diff --git a/helper/pluginutil/builtin.go b/helper/pluginutil/builtin.go new file mode 100644 index 000000000..6a464bb82 --- /dev/null +++ b/helper/pluginutil/builtin.go @@ -0,0 +1,6 @@ +package pluginutil + +var BuiltinPlugins = map[string]func() error{ +// "mysql-database-plugin": mysql.Run, +// "postgres-database-plugin": postgres.Run, +} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index eccac2bd1..c6e4e4059 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -10,13 +10,11 @@ import ( "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" - "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" ) var ( pluginCatalogPrefix = "plugin-catalog/" - builtinPlugins = []string{"mysql-database-plugin", "postgres-database-plugin"} ) type PluginCatalog struct { @@ -55,7 +53,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if !strutil.StrListContains(builtinPlugins, name) { + if _, ok := pluginutil.BuiltinPlugins[name]; !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) } From ca2c3d0c531c8f6ff7d232d92f5badf925992110 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 5 Apr 2017 16:20:31 -0700 Subject: [PATCH 057/162] Refactor to use builtin plugins from an external repo --- builtin/logical/database/backend.go | 51 ++- .../database/{dbs => }/databasemiddleware.go | 2 +- builtin/logical/database/dbs/cassandra.go | 108 ----- .../database/dbs/connectionproducer.go | 280 ------------ .../database/dbs/credentialsproducer.go | 83 ---- builtin/logical/database/dbs/db.go | 196 --------- builtin/logical/database/dbs/mssql.go | 219 --------- builtin/logical/database/dbs/mssql_test.go | 221 ---------- builtin/logical/database/dbs/mysql.go | 135 ------ builtin/logical/database/dbs/mysql_test.go | 346 --------------- builtin/logical/database/dbs/postgresql.go | 279 ------------ .../logical/database/dbs/postgresql_test.go | 414 ------------------ .../database/path_config_connection.go | 78 +--- builtin/logical/database/path_roles.go | 11 +- builtin/logical/database/{dbs => }/plugin.go | 44 +- .../logical/database/{dbs => }/plugin_test.go | 2 +- command/plugin-exec.go | 4 +- helper/builtinplugins/builtin.go | 8 + helper/pluginutil/builtin.go | 6 - vault/plugin_catalog.go | 3 +- 20 files changed, 110 insertions(+), 2380 deletions(-) rename builtin/logical/database/{dbs => }/databasemiddleware.go (99%) delete mode 100644 builtin/logical/database/dbs/cassandra.go delete mode 100644 builtin/logical/database/dbs/connectionproducer.go delete mode 100644 builtin/logical/database/dbs/credentialsproducer.go delete mode 100644 builtin/logical/database/dbs/db.go delete mode 100644 builtin/logical/database/dbs/mssql.go delete mode 100644 builtin/logical/database/dbs/mssql_test.go delete mode 100644 builtin/logical/database/dbs/mysql.go delete mode 100644 builtin/logical/database/dbs/mysql_test.go delete mode 100644 builtin/logical/database/dbs/postgresql.go delete mode 100644 builtin/logical/database/dbs/postgresql_test.go rename builtin/logical/database/{dbs => }/plugin.go (88%) rename builtin/logical/database/{dbs => }/plugin_test.go (99%) create mode 100644 helper/builtinplugins/builtin.go delete mode 100644 helper/pluginutil/builtin.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4d069a432..a2fff4ba8 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -4,16 +4,52 @@ import ( "fmt" "strings" "sync" + "time" log "github.com/mgutz/logxi/v1" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) const databaseConfigPath = "database/dbs/" +// DatabaseType is the interface that all database objects must implement. +type DatabaseType interface { + Type() string + CreateUser(statements Statements, username, password, expiration string) error + RenewUser(statements Statements, username, expiration string) error + RevokeUser(statements Statements, username string) error + + Initialize(map[string]interface{}) error + Close() error + + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) (string, error) +} + +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. +type DatabaseConfig struct { + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` +} + +// Statements set in role creation and passed into the database type's functions. +// TODO: Add a way of setting defaults here. +type Statements struct { + CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` + RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` + RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` + RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` +} + func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -30,7 +66,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { }, Paths: []*framework.Path{ - pathConfigureBuiltinConnection(&b), pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), @@ -48,12 +83,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]dbs.DatabaseType) + b.connections = make(map[string]DatabaseType) return &b } type databaseBackend struct { - connections map[string]dbs.DatabaseType + connections map[string]DatabaseType logger log.Logger *framework.Backend @@ -73,7 +108,7 @@ func (b *databaseBackend) closeAllDBs() { // This function is used to retrieve a database object either from the cached // connection map or by using the database config in storage. The caller of this // function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) { +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (DatabaseType, error) { // if the object already is built and cached, return it db, ok := b.connections[name] if ok { @@ -88,14 +123,12 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs. return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) } - var config dbs.DatabaseConfig + var config DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } - factory := config.GetFactory() - - db, err = factory(&config, b.System(), b.logger) + db, err = PluginFactory(&config, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/databasemiddleware.go b/builtin/logical/database/databasemiddleware.go similarity index 99% rename from builtin/logical/database/dbs/databasemiddleware.go rename to builtin/logical/database/databasemiddleware.go index d3f037ecb..5892e8064 100644 --- a/builtin/logical/database/dbs/databasemiddleware.go +++ b/builtin/logical/database/databasemiddleware.go @@ -1,4 +1,4 @@ -package dbs +package database import ( "time" diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go deleted file mode 100644 index 1be26766b..000000000 --- a/builtin/logical/database/dbs/cassandra.go +++ /dev/null @@ -1,108 +0,0 @@ -package dbs - -import ( - "fmt" - "strings" - - "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/strutil" -) - -const ( - defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` - defaultRollbackCQL = `DROP USER '{{username}}';` -) - -type Cassandra struct { - // Session is goroutine safe, however, since we reinitialize - // it when connection info changes, we want to make sure we - // can close it and use a new connection; hence the lock - ConnectionProducer - CredentialsProducer -} - -func (c *Cassandra) Type() string { - return cassandraTypeName -} - -func (c *Cassandra) getConnection() (*gocql.Session, error) { - session, err := c.connection() - if err != nil { - return nil, err - } - - return session.(*gocql.Session), nil -} - -func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - c.Lock() - defer c.Unlock() - - // Get the connection - session, err := c.getConnection() - if err != nil { - return err - } - - creationCQL := statements.CreationStatements - if creationCQL == "" { - creationCQL = defaultCreationCQL - } - rollbackCQL := statements.RollbackStatements - if rollbackCQL == "" { - rollbackCQL = defaultRollbackCQL - } - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - err = session.Query(queryHelper(query, map[string]string{ - "username": username, - "password": password, - })).Exec() - if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - session.Query(queryHelper(query, map[string]string{ - "username": username, - "password": password, - })).Exec() - } - return err - } - } - - return nil -} - -func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error { - // NOOP - return nil -} - -func (c *Cassandra) RevokeUser(statements Statements, username string) error { - // Grab the lock - c.Lock() - defer c.Unlock() - - session, err := c.getConnection() - if err != nil { - return err - } - - err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() - if err != nil { - return fmt.Errorf("error removing user %s", username) - } - - return nil -} diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go deleted file mode 100644 index 31ef2853b..000000000 --- a/builtin/logical/database/dbs/connectionproducer.go +++ /dev/null @@ -1,280 +0,0 @@ -package dbs - -import ( - "crypto/tls" - "database/sql" - "errors" - "fmt" - "strings" - "sync" - "time" - - // Import sql drivers - _ "github.com/denisenkom/go-mssqldb" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - "github.com/mitchellh/mapstructure" - - "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/certutil" - "github.com/hashicorp/vault/helper/tlsutil" -) - -var ( - errNotInitalized = errors.New("connection has not been initalized") -) - -// ConnectionProducer can be used as an embeded interface in the DatabaseType -// definition. It implements the methods dealing with individual database -// connections and is used in all the builtin database types. -type ConnectionProducer interface { - Close() error - Initialize(map[string]interface{}) error - - sync.Locker - connection() (interface{}, error) -} - -// sqlConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases -type sqlConnectionProducer struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - - config *DatabaseConfig - - initalized bool - db *sql.DB - sync.Mutex -} - -func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { - c.Lock() - defer c.Unlock() - - err := mapstructure.Decode(conf, c) - if err != nil { - return err - } - - if _, err := c.connection(); err != nil { - return fmt.Errorf("error initalizing connection: %s", err) - } - - c.initalized = true - - return nil -} - -func (c *sqlConnectionProducer) connection() (interface{}, error) { - // If we already have a DB, test it and return - if c.db != nil { - if err := c.db.Ping(); err == nil { - return c.db, nil - } - // If the ping was unsuccessful, close it and ignore errors as we'll be - // reestablishing anyways - c.db.Close() - } - - // For mssql backend, switch to sqlserver instead - dbType := c.config.DatabaseType - if c.config.DatabaseType == "mssql" { - dbType = "sqlserver" - } - - // Otherwise, attempt to make connection - conn := c.ConnectionURL - - // Ensure timezone is set to UTC for all the conenctions - if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { - if strings.Contains(conn, "?") { - conn += "&timezone=utc" - } else { - conn += "?timezone=utc" - } - } - - var err error - c.db, err = sql.Open(dbType, conn) - if err != nil { - return nil, err - } - - // Set some connection pool settings. We don't need much of this, - // since the request rate shouldn't be high. - c.db.SetMaxOpenConns(c.config.MaxOpenConnections) - c.db.SetMaxIdleConns(c.config.MaxIdleConnections) - c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime) - - return c.db, nil -} - -func (c *sqlConnectionProducer) Close() error { - // Grab the write lock - c.Lock() - defer c.Unlock() - - if c.db != nil { - c.db.Close() - } - - c.db = nil - - return nil -} - -// cassandraConnectionProducer implements ConnectionProducer and provides an -// interface for cassandra databases to make connections. -type cassandraConnectionProducer struct { - Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` - Username string `json:"username" structs:"username" mapstructure:"username"` - Password string `json:"password" structs:"password" mapstructure:"password"` - TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` - InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` - ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` - ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` - TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - - config *DatabaseConfig - initalized bool - session *gocql.Session - sync.Mutex -} - -func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error { - c.Lock() - defer c.Unlock() - - err := mapstructure.Decode(conf, c) - if err != nil { - return err - } - c.initalized = true - - if _, err := c.connection(); err != nil { - return fmt.Errorf("error Initalizing Connection: %s", err) - } - - return nil -} - -func (c *cassandraConnectionProducer) connection() (interface{}, error) { - if !c.initalized { - return nil, errNotInitalized - } - - // If we already have a DB, return it - if c.session != nil { - return c.session, nil - } - - session, err := c.createSession() - if err != nil { - return nil, err - } - - // Store the session in backend for reuse - c.session = session - - return session, nil -} - -func (c *cassandraConnectionProducer) Close() error { - // Grab the write lock - c.Lock() - defer c.Unlock() - - if c.session != nil { - c.session.Close() - } - - c.session = nil - - return nil -} - -func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) - clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: c.Username, - Password: c.Password, - } - - clusterConfig.ProtoVersion = c.ProtocolVersion - if clusterConfig.ProtoVersion == 0 { - clusterConfig.ProtoVersion = 2 - } - - clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - - if c.TLS { - var tlsConfig *tls.Config - if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { - if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") - } - - certBundle := &certutil.CertBundle{} - if len(c.Certificate) > 0 { - certBundle.Certificate = c.Certificate - certBundle.PrivateKey = c.PrivateKey - } - if len(c.IssuingCA) > 0 { - certBundle.IssuingCA = c.IssuingCA - } - - parsedCertBundle, err := certBundle.ToParsedCertBundle() - if err != nil { - return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) - } - - tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) - if err != nil || tlsConfig == nil { - return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) - } - tlsConfig.InsecureSkipVerify = c.InsecureTLS - - if c.TLSMinVersion != "" { - var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("invalid 'tls_min_version' in config") - } - } else { - // MinVersion was not being set earlier. Reset it to - // zero to gracefully handle upgrades. - tlsConfig.MinVersion = 0 - } - } - - clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, - } - } - - session, err := clusterConfig.CreateSession() - if err != nil { - return nil, fmt.Errorf("error creating session: %s", err) - } - - // Set consistency - if c.Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) - if err != nil { - return nil, err - } - - session.SetConsistency(consistencyValue) - } - - // Verify the info - err = session.Query(`LIST USERS`).Exec() - if err != nil { - return nil, fmt.Errorf("error validating connection info: %s", err) - } - - return session, nil -} diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go deleted file mode 100644 index 6bd543f4e..000000000 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ /dev/null @@ -1,83 +0,0 @@ -package dbs - -import ( - "fmt" - "strings" - "time" - - uuid "github.com/hashicorp/go-uuid" -) - -// CredentialsProducer can be used as an embeded interface in the DatabaseType -// definition. It implements the methods for generating user information for a -// particular database type and is used in all the builtin database types. -type CredentialsProducer interface { - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) -} - -// sqlCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. -type sqlCredentialsProducer struct { - displayNameLen int - usernameLen int -} - -func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { - if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { - displayName = displayName[:scp.displayNameLen] - } - userUUID, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - username := fmt.Sprintf("%s-%s", displayName, userUUID) - if scp.usernameLen > 0 && len(username) > scp.usernameLen { - username = username[:scp.usernameLen] - } - - return username, nil -} - -func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { - password, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - - return password, nil -} - -func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { - return time.Now(). - Add(ttl). - Format("2006-01-02 15:04:05-0700"), nil -} - -// cassandraCredentialsProducer implements CredentialsProducer and provides an -// interface for cassandra databases to generate user information. -type cassandraCredentialsProducer struct{} - -func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { - userUUID, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) - username = strings.Replace(username, "-", "_", -1) - - return username, nil -} - -func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { - password, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - - return password, nil -} - -func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { - return "", nil -} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go deleted file mode 100644 index 49b18b3b8..000000000 --- a/builtin/logical/database/dbs/db.go +++ /dev/null @@ -1,196 +0,0 @@ -package dbs - -import ( - "errors" - "fmt" - "strings" - "time" - - "github.com/hashicorp/vault/logical" - log "github.com/mgutz/logxi/v1" -) - -const ( - postgreSQLTypeName = "postgres" - mySQLTypeName = "mysql" - msSQLTypeName = "mssql" - cassandraTypeName = "cassandra" - pluginTypeName = "plugin" -) - -var ( - ErrUnsupportedDatabaseType = errors.New("unsupported database type") - ErrEmptyCreationStatement = errors.New("empty creation statements") - ErrEmptyPluginName = errors.New("empty plugin name") -) - -// Factory function definition -type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) - -// BuiltinFactory is used to build builtin database types. It wraps the database -// object in a logging and metrics middleware. -func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - var dbType DatabaseType - - switch conf.DatabaseType { - case postgreSQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 23, - usernameLen: 63, - } - - dbType = &PostgreSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case mySQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 4, - usernameLen: 16, - } - - dbType = &MySQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case msSQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 10, - usernameLen: 63, - } - - dbType = &MSSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case cassandraTypeName: - connProducer := &cassandraConnectionProducer{} - connProducer.config = conf - - credsProducer := &cassandraCredentialsProducer{} - - dbType = &Cassandra{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - default: - return nil, ErrUnsupportedDatabaseType - } - - // Wrap with metrics middleware - dbType = &databaseMetricsMiddleware{ - next: dbType, - typeStr: dbType.Type(), - } - - // Wrap with tracing middleware - dbType = &databaseTracingMiddleware{ - next: dbType, - typeStr: dbType.Type(), - logger: logger, - } - - return dbType, nil -} - -// PluginFactory is used to build plugin database types. It wraps the database -// object in a logging and metrics middleware. -func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginName == "" { - return nil, ErrEmptyPluginName - } - - pluginMeta, err := sys.LookupPlugin(conf.PluginName) - if err != nil { - return nil, err - } - - // Make sure the database type is set to plugin - conf.DatabaseType = pluginTypeName - - db, err := newPluginClient(sys, pluginMeta) - if err != nil { - return nil, err - } - - // Wrap with metrics middleware - db = &databaseMetricsMiddleware{ - next: db, - typeStr: db.Type(), - } - - // Wrap with tracing middleware - db = &databaseTracingMiddleware{ - next: db, - typeStr: db.Type(), - logger: logger, - } - - return db, nil -} - -// DatabaseType is the interface that all database objects must implement. -type DatabaseType interface { - Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error - RevokeUser(statements Statements, username string) error - - Initialize(map[string]interface{}) error - Close() error - CredentialsProducer -} - -// DatabaseConfig is used by the Factory function to configure a DatabaseType -// object. -type DatabaseConfig struct { - DatabaseType string `json:"type" structs:"type" mapstructure:"type"` - // ConnectionDetails stores the database specific connection settings needed - // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` - PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` -} - -// GetFactory returns the appropriate factory method for the given database -// type. -func (dc *DatabaseConfig) GetFactory() Factory { - if dc.DatabaseType == pluginTypeName { - return PluginFactory - } - - return BuiltinFactory -} - -// Statements set in role creation and passed into the database type's functions. -// TODO: Add a way of setting defaults here. -type Statements struct { - CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` - RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` - RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` - RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` -} - -// Query templates a query for us. -func queryHelper(tpl string, data map[string]string) string { - for k, v := range data { - tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) - } - - return tpl -} diff --git a/builtin/logical/database/dbs/mssql.go b/builtin/logical/database/dbs/mssql.go deleted file mode 100644 index b7439b0a8..000000000 --- a/builtin/logical/database/dbs/mssql.go +++ /dev/null @@ -1,219 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/strutil" -) - -// MSSQL is an implementation of DatabaseType interface -type MSSQL struct { - ConnectionProducer - CredentialsProducer -} - -// Type returns the TypeName for this backend -func (m *MSSQL) Type() string { - return msSQLTypeName -} - -func (m *MSSQL) getConnection() (*sql.DB, error) { - db, err := m.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by -// the CreationStatement provided. -func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -// RenewUser is not supported on MSSQL, so this is a no-op. -func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error { - // NOOP - return nil -} - -// RevokeUser attempts to drop the specified user. It will first attempt to disable login, -// then kill pending connections from that user, and finally drop the user and login from the -// database instance. -func (m *MSSQL) RevokeUser(statements Statements, username string) error { - // Get connection - db, err := m.getConnection() - if err != nil { - return err - } - - // First disable server login - disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) - if err != nil { - return err - } - defer disableStmt.Close() - if _, err := disableStmt.Exec(); err != nil { - return err - } - - // Query for sessions for the login so that we can kill any outstanding - // sessions. There cannot be any active sessions before we drop the logins - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - sessionStmt, err := db.Prepare(fmt.Sprintf( - "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) - if err != nil { - return err - } - defer sessionStmt.Close() - - sessionRows, err := sessionStmt.Query() - if err != nil { - return err - } - defer sessionRows.Close() - - var revokeStmts []string - for sessionRows.Next() { - var sessionID int - err = sessionRows.Scan(&sessionID) - if err != nil { - return err - } - revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) - } - - // Query for database users using undocumented stored procedure for now since - // it is the easiest way to get this information; - // we need to drop the database users before we can drop the login and the role - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) - if err != nil { - return err - } - defer stmt.Close() - - rows, err := stmt.Query() - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var loginName, dbName, qUsername string - var aliasName sql.NullString - err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) - if err != nil { - return err - } - revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) - } - - // we do not stop on error, as we want to remove as - // many permissions as possible right now - var lastStmtError error - for _, query := range revokeStmts { - stmt, err := db.Prepare(query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { - lastStmtError = err - } - } - - // can't drop if not all database users are dropped - if rows.Err() != nil { - return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) - } - if lastStmtError != nil { - return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) - } - - // Drop this login - stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} - -const dropUserSQL = ` -USE [%s] -IF EXISTS - (SELECT name - FROM sys.database_principals - WHERE name = N'%s') -BEGIN - DROP USER [%s] -END -` - -const dropLoginSQL = ` -IF EXISTS - (SELECT name - FROM master.sys.server_principals - WHERE name = N'%s') -BEGIN - DROP LOGIN [%s] -END -` diff --git a/builtin/logical/database/dbs/mssql_test.go b/builtin/logical/database/dbs/mssql_test.go deleted file mode 100644 index f2169299f..000000000 --- a/builtin/logical/database/dbs/mssql_test.go +++ /dev/null @@ -1,221 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "os" - "sync" - "testing" - "time" - - _ "github.com/denisenkom/go-mssqldb" - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v3" -) - -var ( - testMSQLImagePull sync.Once -) - -func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { - if os.Getenv("MSSQL_URL") != "" { - return func() {}, os.Getenv("MSSQL_URL") - } - - pool, err := dockertest.NewPool("") - if err != nil { - t.Fatalf("Failed to connect to docker: %s", err) - } - - resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) - if err != nil { - t.Fatalf("Could not start local MSSQL docker container: %s", err) - } - - cleanup = func() { - err := pool.Purge(resource) - if err != nil { - t.Fatalf("Failed to cleanup local DynamoDB: %s", err) - } - } - - retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) - - // exponential backoff-retry, because the mssql container may not be able to accept connections yet - if err = pool.Retry(func() error { - var err error - var db *sql.DB - db, err = sql.Open("mssql", retURL) - if err != nil { - return err - } - return db.Ping() - }); err != nil { - t.Fatalf("Could not connect to MSSQL docker container: %s", err) - } - - return -} - -func TestMSSQL_Initialize(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying mssql object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*MSSQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestMSSQL_CreateUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testMSSQLRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMSSQL_RevokeUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMSSQLRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -const testMSSQLRole = ` -CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; -CREATE USER [{{name}}] FOR LOGIN [{{name}}]; -GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go deleted file mode 100644 index 54940d8f6..000000000 --- a/builtin/logical/database/dbs/mysql.go +++ /dev/null @@ -1,135 +0,0 @@ -package dbs - -import ( - "database/sql" - "strings" - - "github.com/hashicorp/vault/helper/strutil" -) - -const defaultMysqlRevocationStmts = ` - REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; - DROP USER '{{name}}'@'%' -` - -type MySQL struct { - ConnectionProducer - CredentialsProducer -} - -func (m *MySQL) Type() string { - return mySQLTypeName -} - -func (m *MySQL) getConnection() (*sql.DB, error) { - db, err := m.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -// NOOP -func (m *MySQL) RenewUser(statements Statements, username, expiration string) error { - return nil -} - -func (m *MySQL) RevokeUser(statements Statements, username string) error { - // Grab the read lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - revocationStmts := statements.RevocationStatements - // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmts == "" { - revocationStmts = defaultMysqlRevocationStmts - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - // This is not a prepared statement because not all commands are supported - // 1295: This command is not supported in the prepared statement protocol yet - // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ - query = strings.Replace(query, "{{name}}", username, -1) - _, err = tx.Exec(query) - if err != nil { - return err - } - - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go deleted file mode 100644 index 553acc8ff..000000000 --- a/builtin/logical/database/dbs/mysql_test.go +++ /dev/null @@ -1,346 +0,0 @@ -package dbs - -import ( - "database/sql" - "os" - "sync" - "testing" - "time" - - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v2" -) - -var ( - testMySQLImagePull sync.Once -) - -func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("MYSQL_URL") != "" { - return "", os.Getenv("MYSQL_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testMySQLImagePull.Do(func() { - dockertest.Pull("mysql") - }) - - cid, connErr := dockertest.ConnectToMySQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - connProducer := &sqlConnectionProducer{} - connProducer.ConnectionURL = connURL - connProducer.config = &DatabaseConfig{ - DatabaseType: mySQLTypeName, - } - - conn, err := connProducer.connection() - if err != nil { - return false - } - if err := conn.(*sql.DB).Ping(); err != nil { - return false - } - - connProducer.Close() - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func TestMySQL_Initialize(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying mysql object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*MySQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestMySQL_CreateUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - statements.CreationStatements = testMySQLRoleHost - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMySQL_RenewUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(statements, username, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMySQL_RevokeUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements.CreationStatements = testMySQLRoleHost - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test custom revoke statements - statements.RevocationStatements = testMySQLRevocationSQL - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - -} - -const testMySQLRoleWildCard = ` -CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'%'; -` -const testMySQLRoleHost = ` -CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2'; -` -const testMySQLRevocationSQL = ` -REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2'; -DROP USER '{{name}}'@'10.1.1.2'; -` diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go deleted file mode 100644 index c8ba110cf..000000000 --- a/builtin/logical/database/dbs/postgresql.go +++ /dev/null @@ -1,279 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/strutil" - "github.com/lib/pq" -) - -type PostgreSQL struct { - ConnectionProducer - CredentialsProducer -} - -func (p *PostgreSQL) Type() string { - return postgreSQLTypeName -} - -func (p *PostgreSQL) getConnection() (*sql.DB, error) { - db, err := p.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error { - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Grab the lock - p.Lock() - defer p.Unlock() - - // Get the connection - db, err := p.getConnection() - if err != nil { - return err - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - tx.Rollback() - }() - // Return the secret - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - "expiration": expiration, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error { - // Grab the lock - p.Lock() - defer p.Unlock() - - db, err := p.getConnection() - if err != nil { - return err - } - - query := fmt.Sprintf( - "ALTER ROLE %s VALID UNTIL '%s';", - pq.QuoteIdentifier(username), - expiration) - - stmt, err := db.Prepare(query) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) RevokeUser(statements Statements, username string) error { - // Grab the lock - p.Lock() - defer p.Unlock() - - if statements.RevocationStatements == "" { - return p.defaultRevokeUser(username) - } - - return p.customRevokeUser(username, statements.RevocationStatements) -} - -func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { - db, err := p.getConnection() - if err != nil { - return err - } - - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - tx.Rollback() - }() - - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - })) - if err != nil { - return err - } - defer stmt.Close() - - if _, err := stmt.Exec(); err != nil { - return err - } - } - - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) defaultRevokeUser(username string) error { - db, err := p.getConnection() - if err != nil { - return err - } - - // Check if the role exists - var exists bool - err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) - if err != nil && err != sql.ErrNoRows { - return err - } - - if exists == false { - return nil - } - - // Query for permissions; we need to revoke permissions before we can drop - // the role - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") - if err != nil { - return err - } - defer stmt.Close() - - rows, err := stmt.Query(username) - if err != nil { - return err - } - defer rows.Close() - - const initialNumRevocations = 16 - revocationStmts := make([]string, 0, initialNumRevocations) - for rows.Next() { - var schema string - err = rows.Scan(&schema) - if err != nil { - // keep going; remove as many permissions as possible right now - continue - } - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, - pq.QuoteIdentifier(schema), - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE USAGE ON SCHEMA %s FROM %s;`, - pq.QuoteIdentifier(schema), - pq.QuoteIdentifier(username))) - } - - // for good measure, revoke all privileges and usage on schema public - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - "REVOKE USAGE ON SCHEMA public FROM %s;", - pq.QuoteIdentifier(username))) - - // get the current database name so we can issue a REVOKE CONNECT for - // this username - var dbname sql.NullString - if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { - return err - } - - if dbname.Valid { - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE CONNECT ON DATABASE %s FROM %s;`, - pq.QuoteIdentifier(dbname.String), - pq.QuoteIdentifier(username))) - } - - // again, here, we do not stop on error, as we want to remove as - // many permissions as possible right now - var lastStmtError error - for _, query := range revocationStmts { - stmt, err := db.Prepare(query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { - lastStmtError = err - } - } - - // can't drop if not all privileges are revoked - if rows.Err() != nil { - return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) - } - if lastStmtError != nil { - return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) - } - - // Drop this user - stmt, err = db.Prepare(fmt.Sprintf( - `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go deleted file mode 100644 index 83aed50ba..000000000 --- a/builtin/logical/database/dbs/postgresql_test.go +++ /dev/null @@ -1,414 +0,0 @@ -package dbs - -import ( - "database/sql" - "os" - "sync" - "testing" - "time" - - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v2" -) - -var ( - testPostgresImagePull sync.Once -) - -func preparePostgresTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("PG_URL") != "" { - return "", os.Getenv("PG_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testPostgresImagePull.Do(func() { - dockertest.Pull("postgres") - }) - - cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - connProducer := &sqlConnectionProducer{} - connProducer.ConnectionURL = connURL - connProducer.config = &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - } - - conn, err := connProducer.connection() - if err != nil { - return false - } - if err := conn.(*sql.DB).Ping(); err != nil { - return false - } - - connProducer.Close() - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { - err := cid.KillRemove() - if err != nil { - t.Fatal(err) - } -} - -func TestPostgreSQL_Initialize(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying postgres object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*PostgreSQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestPostgreSQL_CreateUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - statements.CreationStatements = testPostgresReadOnlyRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - /* statements.CreationStatements = testBlockStatementRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - }*/ -} - -func TestPostgreSQL_RenewUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(statements, username, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestPostgreSQL_RevokeUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test custom revoke statements - statements.RevocationStatements = defaultPostgresRevocationSQL - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - -} - -const testPostgresRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -` - -const testPostgresReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - -const testPostgresBlockStatementRole = ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ - -CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; -GRANT "foo-role" TO "{{name}}"; -ALTER ROLE "{{name}}" SET search_path = foo; -GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; -` - -var testPostgresBlockStatementRoleSlice = []string{ - ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ -`, - `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, - `GRANT "foo-role" TO "{{name}}";`, - `ALTER ROLE "{{name}}" SET search_path = foo;`, - `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, -} - -const defaultPostgresRevocationSQL = ` -REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; -REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; -REVOKE USAGE ON SCHEMA public FROM "{{name}}"; - -DROP ROLE IF EXISTS "{{name}}"; -` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index be2038c31..48d9b8880 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -3,10 +3,8 @@ package database import ( "fmt" "strings" - "time" "github.com/fatih/structs" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -50,16 +48,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -// pathConfigureBuiltinConnection returns a configured framework.Path setup to -// operate on builtin databases. -func pathConfigureBuiltinConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) -} - // pathConfigurePluginConnection returns a configured framework.Path setup to // operate on plugins. func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) + return buildConfigConnectionPath("config/%s", b.connectionWriteHandler(), b.connectionReadHandler(), b.connectionDeleteHandler()) } // buildConfigConnectionPath reutns a configured framework.Path using the passed @@ -74,40 +66,12 @@ func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework Description: "Name of this DB type", }, - "connection_type": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "DB type (e.g. postgres)", - }, - "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, Description: `If set, connection_url is verified by actually connecting to the database`, }, - "max_open_connections": &framework.FieldSchema{ - Type: framework.TypeInt, - Description: `Maximum number of open connections to the database; -a zero uses the default value of two and a -negative value means unlimited`, - }, - - "max_idle_connections": &framework.FieldSchema{ - Type: framework.TypeInt, - Description: `Maximum number of idle connections to the database; -a zero uses the value of max_open_connections -and a negative value disables idle connections. -If larger than max_open_connections it will be -reduced to the same size.`, - }, - - "max_connection_lifetime": &framework.FieldSchema{ - Type: framework.TypeString, - Default: "0s", - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, - }, - "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `Maximum amount of time a connection may be reused; @@ -139,7 +103,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return nil, nil } - var config dbs.DatabaseConfig + var config DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -180,40 +144,12 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { // connectionWriteHandler returns a handler function for creating and updating // both builtin and plugin database types. -func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { +func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connType := data.Get("connection_type").(string) - if connType == "" { - return logical.ErrorResponse("connection_type not set"), nil - } - maxOpenConns := data.Get("max_open_connections").(int) - if maxOpenConns == 0 { - maxOpenConns = 2 - } - - maxIdleConns := data.Get("max_idle_connections").(int) - if maxIdleConns == 0 { - maxIdleConns = maxOpenConns - } - if maxIdleConns > maxOpenConns { - maxIdleConns = maxOpenConns - } - - maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) - maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_connection_lifetime: %s", err)), nil - } - - config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: data.Raw, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, - MaxConnectionLifetime: maxConnLifetime, - PluginName: data.Get("plugin_name").(string), + config := &DatabaseConfig{ + ConnectionDetails: data.Raw, + PluginName: data.Get("plugin_name").(string), } name := data.Get("name").(string) @@ -227,7 +163,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. b.Lock() defer b.Unlock() - db, err := factory(config, b.System(), b.logger) + db, err := PluginFactory(config, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 6f62c79d9..d099ef178 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -156,7 +155,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F "Invalid max_ttl: %s", err)), nil } - statements := dbs.Statements{ + statements := Statements{ CreationStatements: creationStmts, RevocationStatements: revocationStmts, RollbackStatements: rollbackStmts, @@ -183,10 +182,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + Statements Statements `json:"statments" mapstructure:"statements" structs:"statments"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/plugin.go similarity index 88% rename from builtin/logical/database/dbs/plugin.go rename to builtin/logical/database/plugin.go index 441f97ca0..5a6a8e328 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/plugin.go @@ -1,6 +1,7 @@ -package dbs +package database import ( + "errors" "fmt" "net/rpc" "sync" @@ -8,8 +9,47 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/logical" + log "github.com/mgutz/logxi/v1" ) +var ( + ErrEmptyPluginName = errors.New("empty plugin name") +) + +// PluginFactory is used to build plugin database types. It wraps the database +// object in a logging and metrics middleware. +func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { + if conf.PluginName == "" { + return nil, ErrEmptyPluginName + } + + pluginMeta, err := sys.LookupPlugin(conf.PluginName) + if err != nil { + return nil, err + } + + db, err := newPluginClient(sys, pluginMeta) + if err != nil { + return nil, err + } + + // Wrap with metrics middleware + db = &databaseMetricsMiddleware{ + next: db, + typeStr: db.Type(), + } + + // Wrap with tracing middleware + db = &databaseTracingMiddleware{ + next: db, + typeStr: db.Type(), + logger: logger, + } + + return db, nil +} + // handshakeConfigs are used to just do a basic handshake between // a plugin and host. If the handshake fails, a user friendly error is shown. // This prevents users from executing bad plugins or executing a plugin @@ -33,7 +73,7 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e } // DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close -// method to also call Close() on the plugin.Client. +// method to also call Kill() on the plugin.Client. type DatabasePluginClient struct { client *plugin.Client sync.Mutex diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/plugin_test.go similarity index 99% rename from builtin/logical/database/dbs/plugin_test.go rename to builtin/logical/database/plugin_test.go index 60cb6814d..2ec01c955 100644 --- a/builtin/logical/database/dbs/plugin_test.go +++ b/builtin/logical/database/plugin_test.go @@ -1,4 +1,4 @@ -package dbs +package database import ( "crypto/sha256" diff --git a/command/plugin-exec.go b/command/plugin-exec.go index f0d6a8d51..70bc8ae1d 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/meta" ) @@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - runner, ok := pluginutil.BuiltinPlugins[pluginName] + runner, ok := builtinplugins.BuiltinPlugins[pluginName] if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go new file mode 100644 index 000000000..6880640d1 --- /dev/null +++ b/helper/builtinplugins/builtin.go @@ -0,0 +1,8 @@ +package builtinplugins + +import "github.com/hashicorp/vault-plugins/database/mysql" + +var BuiltinPlugins = map[string]func() error{ + "mysql-database-plugin": mysql.Run, + // "postgres-database-plugin": postgres.Run, +} diff --git a/helper/pluginutil/builtin.go b/helper/pluginutil/builtin.go deleted file mode 100644 index 6a464bb82..000000000 --- a/helper/pluginutil/builtin.go +++ /dev/null @@ -1,6 +0,0 @@ -package pluginutil - -var BuiltinPlugins = map[string]func() error{ -// "mysql-database-plugin": mysql.Run, -// "postgres-database-plugin": postgres.Run, -} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index c6e4e4059..b9c15db22 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" @@ -53,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := pluginutil.BuiltinPlugins[name]; !ok { + if _, ok := builtinplugins.BuiltinPlugins[name]; !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) } From 2e23cf58b8e1b65a2d37a83c19860b26322ceea0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 5 Apr 2017 17:19:29 -0700 Subject: [PATCH 058/162] Add postgres builtin plugin --- helper/builtinplugins/builtin.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 6880640d1..ceaf10edf 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,8 +1,11 @@ package builtinplugins -import "github.com/hashicorp/vault-plugins/database/mysql" +import ( + "github.com/hashicorp/vault-plugins/database/mysql" + "github.com/hashicorp/vault-plugins/database/postgresql" +) var BuiltinPlugins = map[string]func() error{ - "mysql-database-plugin": mysql.Run, - // "postgres-database-plugin": postgres.Run, + "mysql-database-plugin": mysql.Run, + "postgresql-database-plugin": postgresql.Run, } From 62d59e5f4e496b34436de3537200073a518c4475 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 6 Apr 2017 12:20:10 -0700 Subject: [PATCH 059/162] Move plugin code into sub directory --- builtin/logical/database/backend.go | 39 +-- builtin/logical/database/dbplugin/client.go | 148 ++++++++ .../{ => dbplugin}/databasemiddleware.go | 2 +- builtin/logical/database/dbplugin/plugin.go | 126 +++++++ .../database/{ => dbplugin}/plugin_test.go | 2 +- builtin/logical/database/dbplugin/server.go | 90 +++++ .../database/path_config_connection.go | 3 +- builtin/logical/database/path_roles.go | 11 +- builtin/logical/database/plugin.go | 324 ------------------ helper/pluginutil/runner.go | 5 + 10 files changed, 385 insertions(+), 365 deletions(-) create mode 100644 builtin/logical/database/dbplugin/client.go rename builtin/logical/database/{ => dbplugin}/databasemiddleware.go (99%) create mode 100644 builtin/logical/database/dbplugin/plugin.go rename builtin/logical/database/{ => dbplugin}/plugin_test.go (99%) create mode 100644 builtin/logical/database/dbplugin/server.go delete mode 100644 builtin/logical/database/plugin.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index a2fff4ba8..baa05a092 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -4,50 +4,23 @@ import ( "fmt" "strings" "sync" - "time" log "github.com/mgutz/logxi/v1" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) const databaseConfigPath = "database/dbs/" -// DatabaseType is the interface that all database objects must implement. -type DatabaseType interface { - Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error - RevokeUser(statements Statements, username string) error - - Initialize(map[string]interface{}) error - Close() error - - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) -} - // DatabaseConfig is used by the Factory function to configure a DatabaseType // object. type DatabaseConfig struct { PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` // ConnectionDetails stores the database specific connection settings needed // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` -} - -// Statements set in role creation and passed into the database type's functions. -// TODO: Add a way of setting defaults here. -type Statements struct { - CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` - RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` - RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` - RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` } func Factory(conf *logical.BackendConfig) (logical.Backend, error) { @@ -83,12 +56,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]DatabaseType) + b.connections = make(map[string]dbplugin.DatabaseType) return &b } type databaseBackend struct { - connections map[string]DatabaseType + connections map[string]dbplugin.DatabaseType logger log.Logger *framework.Backend @@ -108,7 +81,7 @@ func (b *databaseBackend) closeAllDBs() { // This function is used to retrieve a database object either from the cached // connection map or by using the database config in storage. The caller of this // function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (DatabaseType, error) { +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.DatabaseType, error) { // if the object already is built and cached, return it db, ok := b.connections[name] if ok { @@ -128,7 +101,7 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (Data return nil, err } - db, err = PluginFactory(&config, b.System(), b.logger) + db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go new file mode 100644 index 000000000..db6b3d1fd --- /dev/null +++ b/builtin/logical/database/dbplugin/client.go @@ -0,0 +1,148 @@ +package dbplugin + +import ( + "fmt" + "net/rpc" + "sync" + "time" + + "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/pluginutil" +) + +// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close +// method to also call Kill() on the plugin.Client. +type DatabasePluginClient struct { + client *plugin.Client + sync.Mutex + + *databasePluginRPCClient +} + +func (dc *DatabasePluginClient) Close() error { + err := dc.databasePluginRPCClient.Close() + dc.client.Kill() + + return err +} + +// newPluginClient returns a databaseRPCClient with a connection to a running +// plugin. The client is wrapped in a DatabasePluginClient object to ensure the +// plugin is killed on call of Close(). +func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": new(DatabasePlugin), + } + + client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}) + if err != nil { + return nil, err + } + + // Connect via RPC + rpcClient, err := client.Client() + if err != nil { + return nil, err + } + + // Request the plugin + raw, err := rpcClient.Dispense("database") + if err != nil { + return nil, err + } + + // We should have a Greeter now! This feels like a normal interface + // implementation but is in fact over an RPC connection. + databaseRPC := raw.(*databasePluginRPCClient) + + return &DatabasePluginClient{ + client: client, + databasePluginRPCClient: databaseRPC, + }, nil +} + +// ---- RPC client domain ---- + +// databasePluginRPCClient impliments DatabaseType and is used on the client to +// make RPC calls to a plugin. +type databasePluginRPCClient struct { + client *rpc.Client +} + +func (dr *databasePluginRPCClient) Type() string { + var dbType string + //TODO: catch error + dr.client.Call("Plugin.Type", struct{}{}, &dbType) + + return fmt.Sprintf("plugin-%s", dbType) +} + +func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { + req := CreateUserRequest{ + Statements: statements, + Username: username, + Password: password, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { + req := RenewUserRequest{ + Statements: statements, + Username: username, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { + req := RevokeUserRequest{ + Statements: statements, + Username: username, + } + + err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { + err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Close() error { + err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { + resp := &GenerateUsernameResponse{} + err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) + + return resp.Username, err +} + +func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { + resp := &GeneratePasswordResponse{} + err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) + + return resp.Password, err +} + +func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { + resp := &GenerateExpirationResponse{} + err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) + + return resp.Expiration, err +} diff --git a/builtin/logical/database/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go similarity index 99% rename from builtin/logical/database/databasemiddleware.go rename to builtin/logical/database/dbplugin/databasemiddleware.go index 5892e8064..b4a980950 100644 --- a/builtin/logical/database/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -1,4 +1,4 @@ -package database +package dbplugin import ( "time" diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go new file mode 100644 index 000000000..994f3b0ce --- /dev/null +++ b/builtin/logical/database/dbplugin/plugin.go @@ -0,0 +1,126 @@ +package dbplugin + +import ( + "errors" + "net/rpc" + "time" + + "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/pluginutil" + log "github.com/mgutz/logxi/v1" +) + +var ( + ErrEmptyPluginName = errors.New("empty plugin name") +) + +// DatabaseType is the interface that all database objects must implement. +type DatabaseType interface { + Type() string + CreateUser(statements Statements, username, password, expiration string) error + RenewUser(statements Statements, username, expiration string) error + RevokeUser(statements Statements, username string) error + + Initialize(map[string]interface{}) error + Close() error + + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) (string, error) +} + +// Statements set in role creation and passed into the database type's functions. +// TODO: Add a way of setting defaults here. +type Statements struct { + CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` + RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` + RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` + RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` +} + +// PluginFactory is used to build plugin database types. It wraps the database +// object in a logging and metrics middleware. +func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { + if pluginName == "" { + return nil, ErrEmptyPluginName + } + + pluginMeta, err := sys.LookupPlugin(pluginName) + if err != nil { + return nil, err + } + + db, err := newPluginClient(sys, pluginMeta) + if err != nil { + return nil, err + } + + // Wrap with metrics middleware + db = &databaseMetricsMiddleware{ + next: db, + typeStr: db.Type(), + } + + // Wrap with tracing middleware + db = &databaseTracingMiddleware{ + next: db, + typeStr: db.Type(), + logger: logger, + } + + return db, nil +} + +// handshakeConfigs are used to just do a basic handshake between +// a plugin and host. If the handshake fails, a user friendly error is shown. +// This prevents users from executing bad plugins or executing a plugin +// directory. It is a UX feature, not a security feature. +var handshakeConfig = plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "VAULT_DATABASE_PLUGIN", + MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", +} + +type DatabasePlugin struct { + impl DatabaseType +} + +func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { + return &databasePluginRPCServer{impl: d.impl}, nil +} + +func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { + return &databasePluginRPCClient{client: c}, nil +} + +// ---- RPC Request Args Domain ---- + +type CreateUserRequest struct { + Statements Statements + Username string + Password string + Expiration string +} + +type RenewUserRequest struct { + Statements Statements + Username string + Expiration string +} + +type RevokeUserRequest struct { + Statements Statements + Username string +} + +// ---- RPC Response Args Domain ---- + +type GenerateUsernameResponse struct { + Username string +} +type GenerateExpirationResponse struct { + Expiration string +} +type GeneratePasswordResponse struct { + Password string +} diff --git a/builtin/logical/database/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go similarity index 99% rename from builtin/logical/database/plugin_test.go rename to builtin/logical/database/dbplugin/plugin_test.go index 2ec01c955..849e1ebbf 100644 --- a/builtin/logical/database/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -1,4 +1,4 @@ -package database +package dbplugin import ( "crypto/sha256" diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go new file mode 100644 index 000000000..018d9b8db --- /dev/null +++ b/builtin/logical/database/dbplugin/server.go @@ -0,0 +1,90 @@ +package dbplugin + +import ( + "time" + + "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/pluginutil" +) + +// NewPluginServer is called from within a plugin and wraps the provided +// DatabaseType implimentation in a databasePluginRPCServer object and starts a +// RPC server. +func NewPluginServer(db DatabaseType) { + dbPlugin := &DatabasePlugin{ + impl: db, + } + + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": dbPlugin, + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: handshakeConfig, + Plugins: pluginMap, + TLSProvider: pluginutil.VaultPluginTLSProvider, + }) +} + +// ---- RPC server domain ---- + +// databasePluginRPCServer impliments DatabaseType and is run inside a plugin +type databasePluginRPCServer struct { + impl DatabaseType +} + +func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { + *resp = ds.impl.Type() + return nil +} + +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { + err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { + err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { + err := ds.impl.RevokeUser(args.Statements, args.Username) + + return err +} + +func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { + err := ds.impl.Initialize(args) + + return err +} + +func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { + ds.impl.Close() + return nil +} + +func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { + var err error + resp.Username, err = ds.impl.GenerateUsername(args) + + return err +} + +func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { + var err error + resp.Password, err = ds.impl.GeneratePassword() + + return err +} + +func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { + var err error + resp.Expiration, err = ds.impl.GenerateExpiration(args) + + return err +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 48d9b8880..4af6e70a0 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/fatih/structs" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -163,7 +164,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.Lock() defer b.Unlock() - db, err := PluginFactory(config, b.System(), b.logger) + db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index d099ef178..b3ef6f753 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -155,7 +156,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F "Invalid max_ttl: %s", err)), nil } - statements := Statements{ + statements := dbplugin.Statements{ CreationStatements: creationStmts, RevocationStatements: revocationStmts, RollbackStatements: rollbackStmts, @@ -182,10 +183,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - Statements Statements `json:"statments" mapstructure:"statements" structs:"statments"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + Statements dbplugin.Statements `json:"statments" mapstructure:"statements" structs:"statments"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/plugin.go b/builtin/logical/database/plugin.go deleted file mode 100644 index 5a6a8e328..000000000 --- a/builtin/logical/database/plugin.go +++ /dev/null @@ -1,324 +0,0 @@ -package database - -import ( - "errors" - "fmt" - "net/rpc" - "sync" - "time" - - "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vault/helper/pluginutil" - "github.com/hashicorp/vault/logical" - log "github.com/mgutz/logxi/v1" -) - -var ( - ErrEmptyPluginName = errors.New("empty plugin name") -) - -// PluginFactory is used to build plugin database types. It wraps the database -// object in a logging and metrics middleware. -func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginName == "" { - return nil, ErrEmptyPluginName - } - - pluginMeta, err := sys.LookupPlugin(conf.PluginName) - if err != nil { - return nil, err - } - - db, err := newPluginClient(sys, pluginMeta) - if err != nil { - return nil, err - } - - // Wrap with metrics middleware - db = &databaseMetricsMiddleware{ - next: db, - typeStr: db.Type(), - } - - // Wrap with tracing middleware - db = &databaseTracingMiddleware{ - next: db, - typeStr: db.Type(), - logger: logger, - } - - return db, nil -} - -// handshakeConfigs are used to just do a basic handshake between -// a plugin and host. If the handshake fails, a user friendly error is shown. -// This prevents users from executing bad plugins or executing a plugin -// directory. It is a UX feature, not a security feature. -var handshakeConfig = plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "VAULT_DATABASE_PLUGIN", - MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", -} - -type DatabasePlugin struct { - impl DatabaseType -} - -func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { - return &databasePluginRPCServer{impl: d.impl}, nil -} - -func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { - return &databasePluginRPCClient{client: c}, nil -} - -// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close -// method to also call Kill() on the plugin.Client. -type DatabasePluginClient struct { - client *plugin.Client - sync.Mutex - - *databasePluginRPCClient -} - -func (dc *DatabasePluginClient) Close() error { - err := dc.databasePluginRPCClient.Close() - dc.client.Kill() - - return err -} - -// newPluginClient returns a databaseRPCClient with a connection to a running -// plugin. The client is wrapped in a DatabasePluginClient object to ensure the -// plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { - // pluginMap is the map of plugins we can dispense. - var pluginMap = map[string]plugin.Plugin{ - "database": new(DatabasePlugin), - } - - client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}) - if err != nil { - return nil, err - } - - // Connect via RPC - rpcClient, err := client.Client() - if err != nil { - return nil, err - } - - // Request the plugin - raw, err := rpcClient.Dispense("database") - if err != nil { - return nil, err - } - - // We should have a Greeter now! This feels like a normal interface - // implementation but is in fact over an RPC connection. - databaseRPC := raw.(*databasePluginRPCClient) - - return &DatabasePluginClient{ - client: client, - databasePluginRPCClient: databaseRPC, - }, nil -} - -// NewPluginServer is called from within a plugin and wraps the provided -// DatabaseType implimentation in a databasePluginRPCServer object and starts a -// RPC server. -func NewPluginServer(db DatabaseType) { - dbPlugin := &DatabasePlugin{ - impl: db, - } - - // pluginMap is the map of plugins we can dispense. - var pluginMap = map[string]plugin.Plugin{ - "database": dbPlugin, - } - - plugin.Serve(&plugin.ServeConfig{ - HandshakeConfig: handshakeConfig, - Plugins: pluginMap, - TLSProvider: pluginutil.VaultPluginTLSProvider, - }) -} - -// ---- RPC client domain ---- - -// databasePluginRPCClient impliments DatabaseType and is used on the client to -// make RPC calls to a plugin. -type databasePluginRPCClient struct { - client *rpc.Client -} - -func (dr *databasePluginRPCClient) Type() string { - var dbType string - //TODO: catch error - dr.client.Call("Plugin.Type", struct{}{}, &dbType) - - return fmt.Sprintf("plugin-%s", dbType) -} - -func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { - req := CreateUserRequest{ - Statements: statements, - Username: username, - Password: password, - Expiration: expiration, - } - - err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { - req := RenewUserRequest{ - Statements: statements, - Username: username, - Expiration: expiration, - } - - err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { - req := RevokeUserRequest{ - Statements: statements, - Username: username, - } - - err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { - err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Close() error { - err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { - resp := &GenerateUsernameResponse{} - err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) - - return resp.Username, err -} - -func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { - resp := &GeneratePasswordResponse{} - err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) - - return resp.Password, err -} - -func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { - resp := &GenerateExpirationResponse{} - err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) - - return resp.Expiration, err -} - -// ---- RPC server domain ---- - -// databasePluginRPCServer impliments DatabaseType and is run inside a plugin -type databasePluginRPCServer struct { - impl DatabaseType -} - -func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - *resp = ds.impl.Type() - return nil -} - -func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { - err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { - err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { - err := ds.impl.RevokeUser(args.Statements, args.Username) - - return err -} - -func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { - err := ds.impl.Initialize(args) - - return err -} - -func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { - ds.impl.Close() - return nil -} - -func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { - var err error - resp.Username, err = ds.impl.GenerateUsername(args) - - return err -} - -func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { - var err error - resp.Password, err = ds.impl.GeneratePassword() - - return err -} - -func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { - var err error - resp.Expiration, err = ds.impl.GenerateExpiration(args) - - return err -} - -// ---- Request Args Domain ---- - -type CreateUserRequest struct { - Statements Statements - Username string - Password string - Expiration string -} - -type RenewUserRequest struct { - Statements Statements - Username string - Expiration string -} - -type RevokeUserRequest struct { - Statements Statements - Username string -} - -// ---- Response Args Domain ---- - -type GenerateUsernameResponse struct { - Username string -} -type GenerateExpirationResponse struct { - Expiration string -} -type GeneratePasswordResponse struct { - Password string -} diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 143a4c839..90569dd9a 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -12,6 +12,11 @@ type Looker interface { LookupPlugin(string) (*PluginRunner, error) } +type LookWrapper interface { + Looker + Wrapper +} + type PluginRunner struct { Name string `json:"name"` Command string `json:"command"` From 93136ea51e3c98803c232db4e085c0a99d136ff8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 7 Apr 2017 15:50:03 -0700 Subject: [PATCH 060/162] Add backend test --- builtin/logical/database/backend_test.go | 567 ++++++++++++++++++ .../database/path_config_connection.go | 2 + builtin/logical/database/path_roles.go | 2 +- builtin/logical/database/secret_creds.go | 2 +- command/plugin-exec.go | 2 +- helper/builtinplugins/builtin.go | 19 +- vault/plugin_catalog.go | 2 +- vault/testing.go | 45 +- 8 files changed, 633 insertions(+), 8 deletions(-) create mode 100644 builtin/logical/database/backend_test.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go new file mode 100644 index 000000000..5cb84476d --- /dev/null +++ b/builtin/logical/database/backend_test.go @@ -0,0 +1,567 @@ +package database + +import ( + "database/sql" + "errors" + "fmt" + "log" + "net" + "os" + "reflect" + "strings" + "sync" + "testing" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" + "github.com/lib/pq" + "github.com/mitchellh/mapstructure" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testImagePull sync.Once +) + +func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) { + if os.Getenv("PG_URL") != "" { + return func() {}, os.Getenv("PG_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"}) + if err != nil { + t.Fatalf("Could not start local PostgreSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + // This will cause a validation to run + resp, err := b.HandleRequest(&logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: "config/postgresql", + Data: map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_url": retURL, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + // It's likely not up and running yet, so return error and try again + return fmt.Errorf("err:%s resp:%#v\n", err, resp) + } + if resp == nil { + t.Fatal("expected warning") + } + + return nil + }); err != nil { + t.Fatalf("Could not connect to PostgreSQL docker container: %s", err) + } + + return +} + +func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) { + core, _, token, ln := vault.TestCoreUnsealedWithListener(t) + http.TestServerWithListener(t, ln, "", core) + sys := vault.TestDynamicSystemView(core) + vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", fmt.Sprintf("%s -test.run=TestBackend_PluginMain", os.Args[0])) + + return core, ln, sys, token +} + +func TestBackend_PluginMain(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + f, _ := builtinplugins.BuiltinPlugins.Get("postgresql-database-plugin") + f() +} + +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": configData, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } +} + +func TestBackend_basic(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if testCredsByCount(t, credsResp, connURL) != 2 { + t.Fatalf("Got wrong number of creds") + } + + // Revoke creds + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: config.StorageView, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": credsResp.Data["username"], + "role": "plugin-role-test", + }, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if testCredsByCount(t, credsResp, connURL) != -1 { + t.Fatalf("Got wrong number of creds") + } + +} + +func TestBackend_roleCrud(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "revocation_statements": defaultRevocationSQL, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := dbplugin.Statements{ + CreationStatements: testRole, + RevocationStatements: defaultRevocationSQL, + } + + var actual dbplugin.Statements + if err := mapstructure.Decode(resp.Data, &actual); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Statements did not match, exepected %#v, got %#v", expected, actual) + } + + // Delete the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.DeleteOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Should be empty + if resp != nil { + t.Fatal("Expected response to be nil") + } +} + +func TestBackend_roleReadOnly(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a readonly role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testReadOnlyRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-readonly-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if i := testCredsByCount(t, credsResp, connURL); i != 2 { + t.Fatalf("Got wrong number of creds got %d, expected 2", i) + } + + // Get readonly creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-readonly-role-test", + Storage: config.StorageView, + Data: data, + } + readOnlyCredsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, readOnlyCredsResp) + } + + if i := testCredsByCount(t, readOnlyCredsResp, connURL); i != 2 { + t.Fatalf("Got wrong number of creds got %d, expected 2", i) + } + + if err := testCreateTable(t, readOnlyCredsResp, connURL); err == nil { + t.Fatal("Read only creds should return error on table creation") + } + + if err := testCreateTable(t, credsResp, connURL); err != nil { + t.Fatalf("Error on table creation: %s", err) + } +} + +func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + t.Fatal(err) + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + returnedRows := func() int { + stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") + if err != nil { + return -1 + } + defer stmt.Close() + + rows, err := stmt.Query(d.Username) + if err != nil { + return -1 + } + defer rows.Close() + + i := 0 + for rows.Next() { + i++ + } + return i + } + + return returnedRows() +} + +func testCreateTable(t *testing.T, resp *logical.Response, connURL string) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + t.Fatal(err) + } + + connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", d.Username, d.Password), 1) + + fmt.Println(connURL) + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + r, err := db.Exec("CREATE TABLE test1 (id SERIAL PRIMARY KEY);") + if err != nil { + return err + } + + if i, _ := r.RowsAffected(); i != 1 { + return errors.New("Did not create db") + } + + return nil +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +REVOKE ALL ON SCHEMA public FROM "{{name}}"; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; +REVOKE USAGE ON SCHEMA public FROM {{name}}; + +DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4af6e70a0..1b8a65831 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -172,10 +172,12 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { err = db.Initialize(config.ConnectionDetails) if err != nil { if !strings.Contains(err.Error(), "Error Initializing Connection") { + db.Close() return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } if verifyConnection { + db.Close() return logical.ErrorResponse("Could not verify connection"), nil } } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index b3ef6f753..a6989df24 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -105,7 +105,7 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statments": role.Statements.CreationStatements, + "creation_statements": role.Statements.CreationStatements, "revocation_statements": role.Statements.RevocationStatements, "rollback_statements": role.Statements.RollbackStatements, "renew_statements": role.Statements.RenewStatements, diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 2b63ea1f8..353541c0c 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("no role name was provided") } role, err := b.Role(req.Storage, roleNameRaw.(string)) diff --git a/command/plugin-exec.go b/command/plugin-exec.go index 70bc8ae1d..575be14b7 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - runner, ok := builtinplugins.BuiltinPlugins[pluginName] + runner, ok := builtinplugins.BuiltinPlugins.Get(pluginName) if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index ceaf10edf..ba3769c90 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -5,7 +5,20 @@ import ( "github.com/hashicorp/vault-plugins/database/postgresql" ) -var BuiltinPlugins = map[string]func() error{ - "mysql-database-plugin": mysql.Run, - "postgresql-database-plugin": postgresql.Run, +var BuiltinPlugins *builtinPlugins = &builtinPlugins{ + plugins: map[string]func() error{ + "mysql-database-plugin": mysql.Run, + "postgresql-database-plugin": postgresql.Run, + }, +} + +// The list of builtin plugins should not be changed by any other package, so we +// store them in an unexported variable in this unexported struct. +type builtinPlugins struct { + plugins map[string]func() error +} + +func (b *builtinPlugins) Get(name string) (func() error, bool) { + f, ok := b.plugins[name] + return f, ok } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index b9c15db22..a42f85ec1 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -54,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := builtinplugins.BuiltinPlugins[name]; !ok { + if _, ok := builtinplugins.BuiltinPlugins.Get(name); !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) } diff --git a/vault/testing.go b/vault/testing.go index 7b914bbdb..fdf55b4e5 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -8,9 +8,13 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "io" "net" "net/http" + "os" "os/exec" + "path/filepath" + "strings" "testing" "time" @@ -306,7 +310,46 @@ func TestKeyCopy(key []byte) []byte { } func TestDynamicSystemView(c *Core) *dynamicSystemView { - return &dynamicSystemView{c, nil} + me := &MountEntry{ + Config: MountConfig{ + DefaultLeaseTTL: 24 * time.Hour, + MaxLeaseTTL: 2 * 24 * time.Hour, + }, + } + + return &dynamicSystemView{c, me} +} + +func TestAddTestPlugin(t testing.TB, c *Core, name, command string) { + parts := strings.Split(command, " ") + + file, err := os.Open(parts[0]) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + hash := sha256.New() + + _, err = io.Copy(hash, file) + if err != nil { + t.Fatal(err) + } + + sum := hash.Sum(nil) + c.pluginCatalog.directory, err = filepath.EvalSymlinks(parts[0]) + if err != nil { + t.Fatal(err) + } + c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory) + + parts[0] = filepath.Base(parts[0]) + command = strings.Join(parts, " ") + + err = c.pluginCatalog.Set(name, command, sum) + if err != nil { + t.Fatal(err) + } } var testLogicalBackends = map[string]logical.Factory{} From 459e3eda4e0b0e2d47e6b3607c239d4c94ee5e78 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 10:35:16 -0700 Subject: [PATCH 061/162] Update backend tests --- builtin/logical/database/backend_test.go | 314 ++++++++---------- .../database/path_config_connection.go | 2 +- 2 files changed, 148 insertions(+), 168 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 5cb84476d..fc41cf3cd 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -2,13 +2,11 @@ package database import ( "database/sql" - "errors" "fmt" "log" "net" "os" "reflect" - "strings" "sync" "testing" @@ -209,8 +207,8 @@ func TestBackend_basic(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, credsResp) } - if testCredsByCount(t, credsResp, connURL) != 2 { - t.Fatalf("Got wrong number of creds") + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") } // Revoke creds @@ -229,12 +227,153 @@ func TestBackend_basic(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - if testCredsByCount(t, credsResp, connURL) != -1 { - t.Fatalf("Got wrong number of creds") + if testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should not exist") } } +func TestBackend_connectionCrud(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": "test", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "revocation_statements": defaultRevocationSQL, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Update the connection + data = map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read connection + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": data, + } + req.Operation = logical.ReadOperation + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } + + // Reset Connection + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "reset/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } + + // Delete Connection + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.DeleteOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read connection + req.Operation = logical.ReadOperation + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Should be empty + if resp != nil { + t.Fatal("Expected response to be nil") + } +} + func TestBackend_roleCrud(t *testing.T) { _, ln, sys, _ := getCore(t) defer ln.Close() @@ -346,119 +485,7 @@ func TestBackend_roleCrud(t *testing.T) { } } -func TestBackend_roleReadOnly(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() - - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - config.System = sys - - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - defer b.Cleanup() - - cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) - defer cleanup() - - // Configure a connection - data := map[string]interface{}{ - "connection_url": connURL, - "plugin_name": "postgresql-database-plugin", - } - req := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: data, - } - resp, err := b.HandleRequest(req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Create a role - data = map[string]interface{}{ - "db_name": "plugin-test", - "creation_statements": testRole, - "default_ttl": "5m", - "max_ttl": "10m", - } - req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "roles/plugin-role-test", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Create a readonly role - data = map[string]interface{}{ - "db_name": "plugin-test", - "creation_statements": testReadOnlyRole, - "default_ttl": "5m", - "max_ttl": "10m", - } - req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "roles/plugin-readonly-role-test", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Get creds - data = map[string]interface{}{} - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "creds/plugin-role-test", - Storage: config.StorageView, - Data: data, - } - credsResp, err := b.HandleRequest(req) - if err != nil || (credsResp != nil && credsResp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, credsResp) - } - - if i := testCredsByCount(t, credsResp, connURL); i != 2 { - t.Fatalf("Got wrong number of creds got %d, expected 2", i) - } - - // Get readonly creds - data = map[string]interface{}{} - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "creds/plugin-readonly-role-test", - Storage: config.StorageView, - Data: data, - } - readOnlyCredsResp, err := b.HandleRequest(req) - if err != nil || (credsResp != nil && credsResp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, readOnlyCredsResp) - } - - if i := testCredsByCount(t, readOnlyCredsResp, connURL); i != 2 { - t.Fatalf("Got wrong number of creds got %d, expected 2", i) - } - - if err := testCreateTable(t, readOnlyCredsResp, connURL); err == nil { - t.Fatal("Read only creds should return error on table creation") - } - - if err := testCreateTable(t, credsResp, connURL); err != nil { - t.Fatalf("Error on table creation: %s", err) - } -} - -func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int { +func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool { var d struct { Username string `mapstructure:"username"` Password string `mapstructure:"password"` @@ -500,44 +527,7 @@ func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int return i } - return returnedRows() -} - -func testCreateTable(t *testing.T, resp *logical.Response, connURL string) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - t.Fatal(err) - } - - connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", d.Username, d.Password), 1) - - fmt.Println(connURL) - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - r, err := db.Exec("CREATE TABLE test1 (id SERIAL PRIMARY KEY);") - if err != nil { - return err - } - - if i, _ := r.RowsAffected(); i != 1 { - return errors.New("Did not create db") - } - - return nil + return returnedRows() == 2 } const testRole = ` @@ -548,16 +538,6 @@ CREATE ROLE "{{name}}" WITH GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; ` -const testReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -REVOKE ALL ON SCHEMA public FROM "{{name}}"; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - const defaultRevocationSQL = ` REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 1b8a65831..7589669a4 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -171,7 +171,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { err = db.Initialize(config.ConnectionDetails) if err != nil { - if !strings.Contains(err.Error(), "Error Initializing Connection") { + if !strings.Contains(err.Error(), "error initalizing connection") { db.Close() return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } From bbbd81220c92b4f3f7818b74d586994f9d13ea4c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 12:24:16 -0700 Subject: [PATCH 062/162] Update the interface for plugins removing functions for creating creds --- builtin/logical/database/dbplugin/client.go | 37 ++------ .../database/dbplugin/databasemiddleware.go | 93 ++----------------- builtin/logical/database/dbplugin/plugin.go | 24 ++--- builtin/logical/database/dbplugin/server.go | 28 +----- builtin/logical/database/path_role_create.go | 19 +--- builtin/logical/database/secret_creds.go | 4 +- 6 files changed, 28 insertions(+), 177 deletions(-) diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index db6b3d1fd..0dae61d27 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -78,20 +78,20 @@ func (dr *databasePluginRPCClient) Type() string { return fmt.Sprintf("plugin-%s", dbType) } -func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { +func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { req := CreateUserRequest{ - Statements: statements, - Username: username, - Password: password, - Expiration: expiration, + Statements: statements, + UsernamePrefix: usernamePrefix, + Expiration: expiration, } - err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) + var resp CreateUserResponse + err = dr.client.Call("Plugin.CreateUser", req, &resp) - return err + return resp.Username, resp.Password, err } -func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { +func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error { req := RenewUserRequest{ Statements: statements, Username: username, @@ -125,24 +125,3 @@ func (dr *databasePluginRPCClient) Close() error { return err } - -func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { - resp := &GenerateUsernameResponse{} - err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) - - return resp.Username, err -} - -func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { - resp := &GeneratePasswordResponse{} - err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) - - return resp.Password, err -} - -func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { - resp := &GenerateExpirationResponse{} - err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) - - return resp.Expiration, err -} diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index b4a980950..2748f2f11 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -20,7 +20,7 @@ func (mw *databaseTracingMiddleware) Type() string { return mw.next.Type() } -func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { +func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { if mw.logger.IsTrace() { defer func(then time.Time) { mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) @@ -28,10 +28,10 @@ func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) } - return mw.next.CreateUser(statements, username, password, expiration) + return mw.next.CreateUser(statements, usernamePrefix, expiration) } -func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { +func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) @@ -75,39 +75,6 @@ func (mw *databaseTracingMiddleware) Close() (err error) { return mw.next.Close() } -func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr) - } - return mw.next.GenerateUsername(displayName) -} - -func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr) - } - return mw.next.GeneratePassword() -} - -func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr) - } - return mw.next.GenerateExpiration(duration) -} - // ---- Metrics Middleware Domain ---- type databaseMetricsMiddleware struct { @@ -120,7 +87,7 @@ func (mw *databaseMetricsMiddleware) Type() string { return mw.next.Type() } -func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { +func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "CreateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) @@ -133,10 +100,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, metrics.IncrCounter([]string{"database", "CreateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) - return mw.next.CreateUser(statements, username, password, expiration) + return mw.next.CreateUser(statements, usernamePrefix, expiration) } -func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { +func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "RenewUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) @@ -199,51 +166,3 @@ func (mw *databaseMetricsMiddleware) Close() (err error) { metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) return mw.next.Close() } - -func (mw *databaseMetricsMiddleware) GenerateUsername(displayName string) (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GenerateUsername"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateUsername"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GenerateUsername", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GenerateUsername"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername"}, 1) - return mw.next.GenerateUsername(displayName) -} - -func (mw *databaseMetricsMiddleware) GeneratePassword() (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GeneratePassword"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GeneratePassword"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GeneratePassword", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GeneratePassword"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword"}, 1) - return mw.next.GeneratePassword() -} - -func (mw *databaseMetricsMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GenerateExpiration"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateExpiration"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GenerateExpiration", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GenerateExpiration"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration"}, 1) - return mw.next.GenerateExpiration(duration) -} diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 994f3b0ce..5cd24e879 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -17,16 +17,12 @@ var ( // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error + CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) + RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error Initialize(map[string]interface{}) error Close() error - - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) } // Statements set in role creation and passed into the database type's functions. @@ -96,16 +92,15 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e // ---- RPC Request Args Domain ---- type CreateUserRequest struct { - Statements Statements - Username string - Password string - Expiration string + Statements Statements + UsernamePrefix string + Expiration time.Time } type RenewUserRequest struct { Statements Statements Username string - Expiration string + Expiration time.Time } type RevokeUserRequest struct { @@ -115,12 +110,7 @@ type RevokeUserRequest struct { // ---- RPC Response Args Domain ---- -type GenerateUsernameResponse struct { +type CreateUserResponse struct { Username string -} -type GenerateExpirationResponse struct { - Expiration string -} -type GeneratePasswordResponse struct { Password string } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 018d9b8db..2dddbaffd 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,8 +1,6 @@ package dbplugin import ( - "time" - "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" ) @@ -39,8 +37,9 @@ func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { return nil } -func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { - err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { + var err error + resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernamePrefix, args.Expiration) return err } @@ -67,24 +66,3 @@ func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { ds.impl.Close() return nil } - -func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { - var err error - resp.Username, err = ds.impl.GenerateUsername(args) - - return err -} - -func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { - var err error - resp.Password, err = ds.impl.GeneratePassword() - - return err -} - -func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { - var err error - resp.Expiration, err = ds.impl.GenerateExpiration(args) - - return err -} diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index d379ef267..5a16c8926 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -2,6 +2,7 @@ package database import ( "fmt" + "time" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -48,24 +49,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - // Generate the username, password and expiration - username, err := db.GenerateUsername(req.DisplayName) - if err != nil { - return nil, err - } - - password, err := db.GeneratePassword() - if err != nil { - return nil, err - } - - expiration, err := db.GenerateExpiration(role.DefaultTTL) - if err != nil { - return nil, err - } + expiration := time.Now().Add(role.DefaultTTL) // Create the user - err = db.CreateUser(role.Statements, username, password, expiration) + username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 353541c0c..5701e373a 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -58,9 +58,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { - expiration := expireTime.Format("2006-01-02 15:04:05-0700") - - err := db.RenewUser(role.Statements, username, expiration) + err := db.RenewUser(role.Statements, username, expireTime) if err != nil { return nil, err } From db91a8054095f13bec39e2961675e062c0a9d5ef Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 14:12:28 -0700 Subject: [PATCH 063/162] Update plugin test --- builtin/logical/database/backend_test.go | 2 +- .../logical/database/dbplugin/plugin_test.go | 199 +++++------------- vault/testing.go | 13 +- 3 files changed, 53 insertions(+), 161 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index fc41cf3cd..5b3a0db42 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -81,7 +81,7 @@ func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, strin core, _, token, ln := vault.TestCoreUnsealedWithListener(t) http.TestServerWithListener(t, ln, "", core) sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", fmt.Sprintf("%s -test.run=TestBackend_PluginMain", os.Args[0])) + vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", "TestBackend_PluginMain") return core, ln, sys, token } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 849e1ebbf..7909bbd4e 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -1,17 +1,13 @@ -package dbplugin +package dbplugin_test import ( - "crypto/sha256" - "encoding/hex" "errors" - "fmt" - "io" "net" "os" - "os/exec" "testing" "time" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" @@ -21,27 +17,26 @@ import ( type mockPlugin struct { users map[string][]string - CredentialsProducer } func (m *mockPlugin) Type() string { return "mock" } -func (m *mockPlugin) CreateUser(statements Statements, username, password, expiration string) error { - err := errors.New("err") - if username == "" || password == "" || expiration == "" { - return err +func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + err = errors.New("err") + if usernamePrefix == "" || expiration.IsZero() { + return "", "", err } - if _, ok := m.users[username]; ok { - return err + if _, ok := m.users[usernamePrefix]; ok { + return "", "", err } - m.users[username] = []string{password, expiration} + m.users[usernamePrefix] = []string{password} - return nil + return usernamePrefix, "test", nil } -func (m *mockPlugin) RenewUser(statements Statements, username, expiration string) error { +func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { err := errors.New("err") - if username == "" || expiration == "" { + if username == "" || expiration.IsZero() { return err } @@ -51,7 +46,7 @@ func (m *mockPlugin) RenewUser(statements Statements, username, expiration strin return nil } -func (m *mockPlugin) RevokeUser(statements Statements, username string) error { +func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error { err := errors.New("err") if username == "" { return err @@ -77,40 +72,11 @@ func (m *mockPlugin) Close() error { return nil } -func getConf(t *testing.T) *DatabaseConfig { - command := fmt.Sprintf("%s -test.run=TestPlugin_Main", os.Args[0]) - cmd := exec.Command(os.Args[0]) - hash := sha256.New() - - file, err := os.Open(cmd.Path) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - _, err = io.Copy(hash, file) - if err != nil { - t.Fatal(err) - } - - sum := hash.Sum(nil) - - conf := &DatabaseConfig{ - DatabaseType: pluginTypeName, - PluginCommand: command, - PluginChecksum: hex.EncodeToString(sum), - ConnectionDetails: map[string]interface{}{ - "test": true, - }, - } - - return conf -} - func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { core, _, _, ln := vault.TestCoreUnsealedWithListener(t) http.TestServerWithListener(t, ln, "", core) sys := vault.TestDynamicSystemView(core) + vault.TestAddTestPlugin(t, core, "test-plugin", "TestPlugin_Main") return core, ln, sys } @@ -123,24 +89,26 @@ func TestPlugin_Main(t *testing.T) { } plugin := &mockPlugin{ - users: make(map[string][]string), - CredentialsProducer: &sqlCredentialsProducer{5, 50}, + users: make(map[string][]string), } - NewPluginServer(plugin) + dbplugin.NewPluginServer(plugin) } func TestPlugin_Initialize(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - dbRaw, err := PluginFactory(conf, sys, &log.NullLogger{}) + dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } - err = dbRaw.Initialize(conf.ConnectionDetails) + connectionDetails := map[string]interface{}{ + "test": 1, + } + + err = dbRaw.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } @@ -155,97 +123,61 @@ func TestPlugin_CreateUser(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - db, err := PluginFactory(conf, sys, &log.NullLogger{}) + db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } defer db.Close() - err = db.Initialize(conf.ConnectionDetails) + connectionDetails := map[string]interface{}{ + "test": 1, + } + + err = db.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } - username, err := db.GenerateUsername("test") + us, pw, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } + if us != "test" || pw != "test" { + t.Fatal("expected username and password to be 'test'") + } - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } // try and save the same user again to verify it saved the first time, this // should return an error - err = db.CreateUser(Statements{}, username, password, expiration) + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err == nil { t.Fatal("expected an error, user wasn't created correctly") } - - // Create one more user - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } } func TestPlugin_RenewUser(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - db, err := PluginFactory(conf, sys, &log.NullLogger{}) + db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } defer db.Close() - err = db.Initialize(conf.ConnectionDetails) + connectionDetails := map[string]interface{}{ + "test": 1, + } + err = db.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } - username, err := db.GenerateUsername("test") + us, _, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(Statements{}, username, expiration) + err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -255,69 +187,34 @@ func TestPlugin_RevokeUser(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - db, err := PluginFactory(conf, sys, &log.NullLogger{}) + db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } defer db.Close() - err = db.Initialize(conf.ConnectionDetails) + connectionDetails := map[string]interface{}{ + "test": 1, + } + err = db.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(Statements{}, username, password, expiration) + us, _, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } // Test default revoke statememts - err = db.RevokeUser(Statements{}, username) + err = db.RevokeUser(dbplugin.Statements{}, us) if err != nil { t.Fatalf("err: %s", err) } // Try adding the same username back so we can verify it was removed - err = db.CreateUser(Statements{}, username, password, expiration) + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // try once more - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.RevokeUser(Statements{}, username) - if err != nil { - t.Fatalf("err: %s", err) - } - } diff --git a/vault/testing.go b/vault/testing.go index fdf55b4e5..b2fe36b33 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -14,7 +14,6 @@ import ( "os" "os/exec" "path/filepath" - "strings" "testing" "time" @@ -320,10 +319,8 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView { return &dynamicSystemView{c, me} } -func TestAddTestPlugin(t testing.TB, c *Core, name, command string) { - parts := strings.Split(command, " ") - - file, err := os.Open(parts[0]) +func TestAddTestPlugin(t testing.TB, c *Core, name, testFunc string) { + file, err := os.Open(os.Args[0]) if err != nil { t.Fatal(err) } @@ -337,15 +334,13 @@ func TestAddTestPlugin(t testing.TB, c *Core, name, command string) { } sum := hash.Sum(nil) - c.pluginCatalog.directory, err = filepath.EvalSymlinks(parts[0]) + c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0]) if err != nil { t.Fatal(err) } c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory) - parts[0] = filepath.Base(parts[0]) - command = strings.Join(parts, " ") - + command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc) err = c.pluginCatalog.Set(name, command, sum) if err != nil { t.Fatal(err) From f6ff3b11468c89724ae0a94a6aefbbcce7dcf349 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 15:36:59 -0700 Subject: [PATCH 064/162] Add a flag to tell plugins to verify the connection was successful --- builtin/logical/database/backend.go | 2 +- builtin/logical/database/dbplugin/client.go | 9 +++++++-- .../database/dbplugin/databasemiddleware.go | 10 +++++----- builtin/logical/database/dbplugin/plugin.go | 8 ++++++-- builtin/logical/database/dbplugin/server.go | 4 ++-- .../logical/database/path_config_connection.go | 15 ++++----------- 6 files changed, 25 insertions(+), 23 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index baa05a092..4cf542d95 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -106,7 +106,7 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return nil, err } - err = db.Initialize(config.ConnectionDetails) + err = db.Initialize(config.ConnectionDetails, true) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 0dae61d27..da39ed425 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -114,8 +114,13 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st return err } -func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { - err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) +func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error { + req := InitializeRequest{ + Config: conf, + VerifyConnection: verifyConnection, + } + + err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) return err } diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 2748f2f11..1df7be3bb 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -53,15 +53,15 @@ func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username return mw.next.RevokeUser(statements, username) } -func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}) (err error) { +func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr) } - return mw.next.Initialize(conf) + return mw.next.Initialize(conf, verifyConnection) } func (mw *databaseTracingMiddleware) Close() (err error) { @@ -135,7 +135,7 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username return mw.next.RevokeUser(statements, username) } -func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}) (err error) { +func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) @@ -148,7 +148,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}) (er metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) - return mw.next.Initialize(conf) + return mw.next.Initialize(conf, verifyConnection) } func (mw *databaseMetricsMiddleware) Close() (err error) { diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 5cd24e879..39655bf46 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -21,12 +21,11 @@ type DatabaseType interface { RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error - Initialize(map[string]interface{}) error + Initialize(config map[string]interface{}, verifyConnection bool) error Close() error } // Statements set in role creation and passed into the database type's functions. -// TODO: Add a way of setting defaults here. type Statements struct { CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` @@ -91,6 +90,11 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e // ---- RPC Request Args Domain ---- +type InitializeRequest struct { + Config map[string]interface{} + VerifyConnection bool +} + type CreateUserRequest struct { Statements Statements UsernamePrefix string diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 2dddbaffd..54b05338c 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -56,8 +56,8 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct return err } -func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { - err := ds.impl.Initialize(args) +func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error { + err := ds.impl.Initialize(args.Config, args.VerifyConnection) return err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 7589669a4..8e78aa425 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -2,7 +2,6 @@ package database import ( "fmt" - "strings" "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" @@ -169,23 +168,17 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } - err = db.Initialize(config.ConnectionDetails) + err = db.Initialize(config.ConnectionDetails, verifyConnection) if err != nil { - if !strings.Contains(err.Error(), "error initalizing connection") { - db.Close() - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - } - - if verifyConnection { - db.Close() - return logical.ErrorResponse("Could not verify connection"), nil - } + db.Close() + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } if _, ok := b.connections[name]; ok { // Close and remove the old connection err := b.connections[name].Close() if err != nil { + db.Close() return nil, err } From 8071aed75890da68eee665ce63e066cd21297dbb Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 17:12:52 -0700 Subject: [PATCH 065/162] Mlock the plugin process --- builtin/logical/database/backend.go | 2 ++ builtin/logical/database/dbplugin/server.go | 8 ++++++ helper/pluginutil/runner.go | 32 ++++++++++++++++++++- helper/pluginutil/tls.go | 4 --- logical/system_view.go | 10 +++++++ vault/core.go | 5 +++- vault/dynamic_system_view.go | 7 +++++ vault/plugin_catalog.go | 2 ++ 8 files changed, 64 insertions(+), 6 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4cf542d95..618ffac6f 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -76,6 +76,8 @@ func (b *databaseBackend) closeAllDBs() { for _, db := range b.connections { db.Close() } + + b.connections = nil } // This function is used to retrieve a database object either from the cached diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 54b05338c..5c1b41a3d 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,6 +1,8 @@ package dbplugin import ( + "fmt" + "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" ) @@ -18,6 +20,12 @@ func NewPluginServer(db DatabaseType) { "database": dbPlugin, } + err := pluginutil.OptionallyEnableMlock() + if err != nil { + fmt.Println(err) + return + } + plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 90569dd9a..4d66d8706 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -3,15 +3,29 @@ package pluginutil import ( "crypto/sha256" "fmt" + "os" "os/exec" + "time" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/mlock" +) + +var ( + // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the + // plugin. + PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" ) type Looker interface { LookupPlugin(string) (*PluginRunner, error) } +type Wrapper interface { + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + MlockDisabled() bool +} + type LookWrapper interface { Looker Wrapper @@ -22,6 +36,7 @@ type PluginRunner struct { Command string `json:"command"` Args []string `json:"args"` Sha256 []byte `json:"sha256"` + Builtin bool `json:"builtin"` } func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { @@ -44,10 +59,17 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return nil, err } - // Add the response wrap token to the ENV of the plugin + mlock := "true" + if wrapper.MlockDisabled() { + mlock = "false" + } + cmd := exec.Command(r.Command, r.Args...) cmd.Env = append(cmd.Env, env...) + // Add the response wrap token to the ENV of the plugin cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + // Add the mlock setting to the ENV of the plugin + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, mlock)) secureConfig := &plugin.SecureConfig{ Checksum: r.Sha256, @@ -64,3 +86,11 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return client, nil } + +func OptionallyEnableMlock() error { + if os.Getenv(PluginMlockEnabled) == "true" { + return mlock.LockMemory() + } + + return nil +} diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 63ae2932f..c7aa42ee6 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -29,10 +29,6 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) -type Wrapper interface { - ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) -} - // GenerateCACert returns a CA cert used to later sign the certificates for the // plugin client and server. func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { diff --git a/logical/system_view.go b/logical/system_view.go index a9626bc50..b69f27090 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -44,7 +44,12 @@ type SystemView interface { // token used to unwrap. ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + // LookupPlugin looks into the plugin catalog for a plugin with the given + // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(string) (*pluginutil.PluginRunner, error) + + // MlockDisabled returns the configuration setting for DisableMlock. + MlockDisabled() bool } type StaticSystemView struct { @@ -54,6 +59,7 @@ type StaticSystemView struct { TaintedVal bool CachingDisabledVal bool Primary bool + DisableMlock bool ReplicationStateVal consts.ReplicationState } @@ -88,3 +94,7 @@ func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { return nil, errors.New("LookupPlugin is not implimented in StaticSystemView") } + +func (d StaticSystemView) MlockDisabled() bool { + return d.DisableMlock +} diff --git a/vault/core.go b/vault/core.go index ffd36683b..9a2f1900e 100644 --- a/vault/core.go +++ b/vault/core.go @@ -332,7 +332,7 @@ type Core struct { // uiEnabled indicates whether Vault Web UI is enabled or not uiEnabled bool - // pluginDirectory is the location vault will look for plugins + // pluginDirectory is the location vault will look for plugin binaries pluginDirectory string // vaultBinaryLocation is used to run builtin plugins in secure mode @@ -343,6 +343,8 @@ type Core struct { // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog + + disableMlock bool } // CoreConfig is used to parameterize a core @@ -449,6 +451,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterListenerShutdownSuccessCh: make(chan struct{}), vaultBinaryLocation: conf.VaultBinaryLocation, vaultBinarySHA256: conf.VaultBinarySHA256, + disableMlock: conf.DisableMlock, } // Wrap the physical backend in a cache layer if enabled and not already wrapped diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index f318f3ab1..ca2b89d6c 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -116,6 +116,13 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim return resp.WrapInfo.Token, nil } +// LookupPlugin looks for a plugin with the given name in the plugin catalog. It +// returns a PluginRunner or an error if no plugin was found. func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { return d.core.pluginCatalog.Get(name) } + +// MlockDisabled returns the configuration setting "DisableMlock". +func (d dynamicSystemView) MlockDisabled() bool { + return d.core.disableMlock +} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index a42f85ec1..737f0c26b 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -63,6 +63,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { Command: c.vaultCommand, Args: []string{"plugin-exec", name}, Sha256: c.vaultSHA256, + Builtin: true, }, nil } @@ -93,6 +94,7 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { Command: command, Args: args, Sha256: sha256, + Builtin: false, } buf, err := json.Marshal(entry) From c85b7be22f1939d23b5d68b53a6bab4e529943a4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 18:38:34 -0700 Subject: [PATCH 066/162] Remove unnecessary abstraction --- .../logical/database/path_config_connection.go | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 8e78aa425..c242aa339 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -51,15 +51,8 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew // pathConfigurePluginConnection returns a configured framework.Path setup to // operate on plugins. func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("config/%s", b.connectionWriteHandler(), b.connectionReadHandler(), b.connectionDeleteHandler()) -} - -// buildConfigConnectionPath reutns a configured framework.Path using the passed -// in operation functions to complete the request. Used to distinguish calls -// between builtin and plugin databases. -func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework.OperationFunc) *framework.Path { return &framework.Path{ - Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")), + Pattern: fmt.Sprintf("config/%s", framework.GenericNameRegex("name")), Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, @@ -80,9 +73,9 @@ func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: updateOp, - logical.ReadOperation: readOp, - logical.DeleteOperation: deleteOp, + logical.UpdateOperation: b.connectionWriteHandler(), + logical.ReadOperation: b.connectionReadHandler(), + logical.DeleteOperation: b.connectionDeleteHandler(), }, HelpSynopsis: pathConfigConnectionHelpSyn, From 128f25c13d6dc93a509313d0ab2cd16e980f43f2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 11 Apr 2017 11:50:34 -0700 Subject: [PATCH 067/162] Update help text and comments --- builtin/logical/database/dbplugin/client.go | 4 +- .../database/dbplugin/databasemiddleware.go | 4 + builtin/logical/database/dbplugin/plugin.go | 4 + builtin/logical/database/dbplugin/server.go | 5 +- .../database/path_config_connection.go | 87 ++++--- builtin/logical/database/path_role_create.go | 78 ++++--- builtin/logical/database/path_roles.go | 218 ++++++++++-------- builtin/logical/database/secret_creds.go | 190 +++++++-------- logical/system_view.go | 4 +- 9 files changed, 323 insertions(+), 271 deletions(-) diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index da39ed425..5bdc3a01a 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -10,7 +10,7 @@ import ( "github.com/hashicorp/vault/helper/pluginutil" ) -// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close +// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close // method to also call Kill() on the plugin.Client. type DatabasePluginClient struct { client *plugin.Client @@ -64,7 +64,7 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn // ---- RPC client domain ---- -// databasePluginRPCClient impliments DatabaseType and is used on the client to +// databasePluginRPCClient implements DatabaseType and is used on the client to // make RPC calls to a plugin. type databasePluginRPCClient struct { client *rpc.Client diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 1df7be3bb..2137cd9c3 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -9,6 +9,8 @@ import ( // ---- Tracing Middleware Domain ---- +// databaseTracingMiddleware wraps a implementation of DatabaseType and executes +// trace logging on function call. type databaseTracingMiddleware struct { next DatabaseType logger log.Logger @@ -77,6 +79,8 @@ func (mw *databaseTracingMiddleware) Close() (err error) { // ---- Metrics Middleware Domain ---- +// databaseMetricsMiddleware wraps an implementation of DatabaseTypes and on +// function call logs metrics about this instance. type databaseMetricsMiddleware struct { next DatabaseType diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 39655bf46..dadb6639e 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -40,11 +40,13 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log return nil, ErrEmptyPluginName } + // Look for plugin in the plugin catalog pluginMeta, err := sys.LookupPlugin(pluginName) if err != nil { return nil, err } + // create a DatabasePluginClient instance db, err := newPluginClient(sys, pluginMeta) if err != nil { return nil, err @@ -76,6 +78,8 @@ var handshakeConfig = plugin.HandshakeConfig{ MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", } +// DatabasePlugin implements go-plugin's Plugin interface. It has methods for +// retrieving a server and a client instance of the plugin. type DatabasePlugin struct { impl DatabaseType } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 5c1b41a3d..326e25103 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -8,7 +8,7 @@ import ( ) // NewPluginServer is called from within a plugin and wraps the provided -// DatabaseType implimentation in a databasePluginRPCServer object and starts a +// DatabaseType implementation in a databasePluginRPCServer object and starts a // RPC server. func NewPluginServer(db DatabaseType) { dbPlugin := &DatabasePlugin{ @@ -35,7 +35,8 @@ func NewPluginServer(db DatabaseType) { // ---- RPC server domain ---- -// databasePluginRPCServer impliments DatabaseType and is run inside a plugin +// databasePluginRPCServer implements an RPC version of DatabaseType and is run +// inside a plugin. It wraps an underlying implementation of DatabaseType. type databasePluginRPCServer struct { impl DatabaseType } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index c242aa339..5817f53c2 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +// pathResetConnection configures a path to reset a plugin. func pathResetConnection(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf("reset/%s", framework.GenericNameRegex("name")), @@ -20,32 +21,36 @@ func pathResetConnection(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: b.pathConnectionReset, + logical.UpdateOperation: b.pathConnectionReset(), }, - HelpSynopsis: pathConfigConnectionHelpSyn, - HelpDescription: pathConfigConnectionHelpDesc, + HelpSynopsis: pathResetConnectionHelpSyn, + HelpDescription: pathResetConnectionHelpDesc, } } -func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) - if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil +// pathConnectionReset resets a plugin by closing the existing instance and +// creating a new one. +func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty name attribute given"), nil + } + + // Grab the mutex lock + b.Lock() + defer b.Unlock() + + b.clearConnection(name) + + _, err := b.getOrCreateDBObj(req.Storage, name) + if err != nil { + return nil, err + } + + return nil, nil } - - // Grab the mutex lock - b.Lock() - defer b.Unlock() - - b.clearConnection(name) - - _, err := b.getOrCreateDBObj(req.Storage, name) - if err != nil { - return nil, err - } - - return nil, nil } // pathConfigurePluginConnection returns a configured framework.Path setup to @@ -60,15 +65,17 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { }, "verify_connection": &framework.FieldSchema{ - Type: framework.TypeBool, - Default: true, - Description: `If set, connection_url is verified by actually connecting to the database`, + Type: framework.TypeBool, + Default: true, + Description: `If set, the connection details are verified by + actually connecting to the database`, }, "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, + Description: `The name of a builtin or previously registered + plugin known to vault. This endpoint will create an instance of + that plugin type.`, }, }, @@ -198,16 +205,32 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } const pathConfigConnectionHelpSyn = ` -Configure the connection string to talk to PostgreSQL. +Configure connection details to a database plugin. ` const pathConfigConnectionHelpDesc = ` -This path configures the connection string used to connect to PostgreSQL. -The value of the string can be a URL, or a PG style string in the -format of "user=foo host=bar" etc. +This path configures the connection details used to connect to a particular +database. This path runs the provided plugin name and passes the configured +connection details to the plugin. See the documentation for the plugin specified +for a full list of accepted connection details. -The URL looks like: -"postgresql://user:pass@host:port/dbname" +In addition to the database specific connection details, this endpoing also +accepts: -When configuring the connection string, the backend will verify its validity. + * "plugin_name" (required) - The name of a builtin or previously registered + plugin known to vault. This endpoint will create an instance of that + plugin type. + + * "verify_connection" - A boolean value denoting if the plugin should verify + it is able to connect to the database using the provided connection + details. +` + +const pathResetConnectionHelpSyn = ` +Resets a database plugin. +` + +const pathResetConnectionHelpDesc = ` +This path resets the database connection by closing the existing database plugin +instance and running a new one. ` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 5a16c8926..59584e943 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -19,7 +19,7 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathRoleCreateRead, + logical.ReadOperation: b.pathRoleCreateRead(), }, HelpSynopsis: pathRoleCreateReadHelpSyn, @@ -27,45 +27,47 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { } } -func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) +func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) - // Get the role - role, err := b.Role(req.Storage, name) - if err != nil { - return nil, err + // Get the role + role, err := b.Role(req.Storage, name) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil + } + + b.Lock() + defer b.Unlock() + + // Get the Database object + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + // TODO: return a resp error instead? + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + + expiration := time.Now().Add(role.DefaultTTL) + + // Create the user + username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) + if err != nil { + return nil, err + } + + resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ + "username": username, + "password": password, + }, map[string]interface{}{ + "username": username, + "role": name, + }) + resp.Secret.TTL = role.DefaultTTL + return resp, nil } - if role == nil { - return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil - } - - b.Lock() - defer b.Unlock() - - // Get the Database object - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - // TODO: return a resp error instead? - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) - } - - expiration := time.Now().Add(role.DefaultTTL) - - // Create the user - username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) - if err != nil { - return nil, err - } - - resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ - "username": username, - "password": password, - }, map[string]interface{}{ - "username": username, - "role": name, - }) - resp.Secret.TTL = role.DefaultTTL - return resp, nil } const pathRoleCreateReadHelpSyn = ` diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index a6989df24..263a555e6 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -14,7 +14,7 @@ func pathListRoles(b *databaseBackend) *framework.Path { Pattern: "roles/?$", Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ListOperation: b.pathRoleList, + logical.ListOperation: b.pathRoleList(), }, HelpSynopsis: pathRoleHelpSyn, @@ -35,12 +35,13 @@ func pathRoles(b *databaseBackend) *framework.Path { Type: framework.TypeString, Description: "Name of the database this role acts on.", }, - "creation_statements": { - Type: framework.TypeString, - Description: "SQL string to create a user. See help for more info.", + Type: framework.TypeString, + Description: `Statements to be executed to create a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}', + '{{password}}', and '{{expiration}}' values will be substituted.`, }, - "revocation_statements": { Type: framework.TypeString, Description: `Statements to be executed to revoke a user. Must be a semicolon-separated @@ -75,9 +76,9 @@ func pathRoles(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathRoleRead, - logical.UpdateOperation: b.pathRoleCreate, - logical.DeleteOperation: b.pathRoleDelete, + logical.ReadOperation: b.pathRoleRead(), + logical.UpdateOperation: b.pathRoleCreate(), + logical.DeleteOperation: b.pathRoleDelete(), }, HelpSynopsis: pathRoleHelpSyn, @@ -85,101 +86,107 @@ func pathRoles(b *databaseBackend) *framework.Path { } } -func (b *databaseBackend) pathRoleDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - err := req.Storage.Delete("role/" + data.Get("name").(string)) - if err != nil { - return nil, err - } +func (b *databaseBackend) pathRoleDelete() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + err := req.Storage.Delete("role/" + data.Get("name").(string)) + if err != nil { + return nil, err + } - return nil, nil -} - -func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - role, err := b.Role(req.Storage, data.Get("name").(string)) - if err != nil { - return nil, err - } - if role == nil { return nil, nil } - - return &logical.Response{ - Data: map[string]interface{}{ - "creation_statements": role.Statements.CreationStatements, - "revocation_statements": role.Statements.RevocationStatements, - "rollback_statements": role.Statements.RollbackStatements, - "renew_statements": role.Statements.RenewStatements, - "default_ttl": role.DefaultTTL.String(), - "max_ttl": role.MaxTTL.String(), - }, - }, nil } -func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - entries, err := req.Storage.List("role/") - if err != nil { - return nil, err - } +func (b *databaseBackend) pathRoleRead() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + role, err := b.Role(req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, nil + } - return logical.ListResponse(entries), nil + return &logical.Response{ + Data: map[string]interface{}{ + "creation_statements": role.Statements.CreationStatements, + "revocation_statements": role.Statements.RevocationStatements, + "rollback_statements": role.Statements.RollbackStatements, + "renew_statements": role.Statements.RenewStatements, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), + }, + }, nil + } } -func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) - if name == "" { - return logical.ErrorResponse("Empty role name attribute given"), nil +func (b *databaseBackend) pathRoleList() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + entries, err := req.Storage.List("role/") + if err != nil { + return nil, err + } + + return logical.ListResponse(entries), nil } +} - dbName := data.Get("db_name").(string) - if dbName == "" { - return logical.ErrorResponse("Empty database name attribute given"), nil +func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty role name attribute given"), nil + } + + dbName := data.Get("db_name").(string) + if dbName == "" { + return logical.ErrorResponse("Empty database name attribute given"), nil + } + + // Get statements + creationStmts := data.Get("creation_statements").(string) + revocationStmts := data.Get("revocation_statements").(string) + rollbackStmts := data.Get("rollback_statements").(string) + renewStmts := data.Get("renew_statements").(string) + + // Get TTLs + defaultTTLRaw := data.Get("default_ttl").(string) + maxTTLRaw := data.Get("max_ttl").(string) + + defaultTTL, err := time.ParseDuration(defaultTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid default_ttl: %s", err)), nil + } + maxTTL, err := time.ParseDuration(maxTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_ttl: %s", err)), nil + } + + statements := dbplugin.Statements{ + CreationStatements: creationStmts, + RevocationStatements: revocationStmts, + RollbackStatements: rollbackStmts, + RenewStatements: renewStmts, + } + + // Store it + entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ + DBName: dbName, + Statements: statements, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil } - - // Get statements - creationStmts := data.Get("creation_statements").(string) - revocationStmts := data.Get("revocation_statements").(string) - rollbackStmts := data.Get("rollback_statements").(string) - renewStmts := data.Get("renew_statements").(string) - - // Get TTLs - defaultTTLRaw := data.Get("default_ttl").(string) - maxTTLRaw := data.Get("max_ttl").(string) - - defaultTTL, err := time.ParseDuration(defaultTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid default_ttl: %s", err)), nil - } - maxTTL, err := time.ParseDuration(maxTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_ttl: %s", err)), nil - } - - statements := dbplugin.Statements{ - CreationStatements: creationStmts, - RevocationStatements: revocationStmts, - RollbackStatements: rollbackStmts, - RenewStatements: renewStmts, - } - - // TODO: Think about preparing the statments to test. - - // Store it - entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - Statements: statements, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, - }) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err - } - - return nil, nil } type roleEntry struct { @@ -196,10 +203,14 @@ Manage the roles that can be created with this backend. const pathRoleHelpDesc = ` This path lets you manage the roles that can be created with this backend. -The "sql" parameter customizes the SQL string used to create the role. -This can be a sequence of SQL queries. Some substitution will be done to the -SQL string for certain keys. The names of the variables must be surrounded -by "{{" and "}}" to be replaced. +The "db_name" parameter is required and configures the name of the database +connection to use. + +The "creation_statements" parameter customizes the string used to create the +credentials. This can be a sequence of SQL queries, or other statement formats +for a particular database type. Some substitution will be done to the statement +strings for certain keys. The names of the variables must be surrounded by "{{" +and "}}" to be replaced. * "name" - The random username generated for the DB user. @@ -207,7 +218,7 @@ by "{{" and "}}" to be replaced. * "expiration" - The timestamp when this user will expire. -Example of a decent SQL query to use: +Example of a decent creation_statements for a postgresql database plugin: CREATE ROLE "{{name}}" WITH LOGIN @@ -215,14 +226,17 @@ Example of a decent SQL query to use: VALID UNTIL '{{expiration}}'; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -Note the above user would be able to access everything in schema public. -For more complex GRANT clauses, see the PostgreSQL manual. - -The "revocation_sql" parameter customizes the SQL string used to revoke a user. -Example of a decent revocation SQL query to use: +The "revocation_statements" parameter customizes the statement string used to +revoke a user. Example of a decent revocation_statements for a postgresql +database plugin: REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; REVOKE USAGE ON SCHEMA public FROM {{name}}; DROP ROLE IF EXISTS {{name}}; + +The "renew_statements" parameter customizes the statement string used to renew a +user. +The "rollback_statements' parameter customizes the statement string used to +rollback a change if needed. ` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 5701e373a..ffc59cf3f 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -14,112 +14,116 @@ func secretCreds(b *databaseBackend) *framework.Secret { Type: SecretCredsType, Fields: map[string]*framework.FieldSchema{}, - Renew: b.secretCredsRenew, - Revoke: b.secretCredsRevoke, + Renew: b.secretCredsRenew(), + Revoke: b.secretCredsRevoke(), } } -func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - // Get the username from the internal data - usernameRaw, ok := req.Secret.InternalData["username"] - if !ok { - return nil, fmt.Errorf("secret is missing username internal data") - } - username, ok := usernameRaw.(string) - - roleNameRaw, ok := req.Secret.InternalData["role"] - if !ok { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) - } - - role, err := b.Role(req.Storage, roleNameRaw.(string)) - if err != nil { - return nil, err - } - if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) - } - - f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) - resp, err := f(req, d) - if err != nil { - return nil, err - } - - // Grab the read lock - b.Lock() - defer b.Unlock() - - // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) - } - - // Make sure we increase the VALID UNTIL endpoint for this user. - if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { - err := db.RenewUser(role.Statements, username, expireTime) - if err != nil { - return nil, err +func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") } - } + username, ok := usernameRaw.(string) - return resp, nil -} + roleNameRaw, ok := req.Secret.InternalData["role"] + if !ok { + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + } -func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - // Get the username from the internal data - usernameRaw, ok := req.Secret.InternalData["username"] - if !ok { - return nil, fmt.Errorf("secret is missing username internal data") - } - username, ok := usernameRaw.(string) - - var resp *logical.Response - - roleNameRaw, ok := req.Secret.InternalData["role"] - if !ok { - return nil, fmt.Errorf("no role name was provided") - } - - role, err := b.Role(req.Storage, roleNameRaw.(string)) - if err != nil { - return nil, err - } - if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) - } - - /* TODO: think about how to handle this case. - if !ok { role, err := b.Role(req.Storage, roleNameRaw.(string)) if err != nil { return nil, err } if role == nil { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) - } else { - revocationSQL = role.RevocationStatement + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } - }*/ - // Grab the read lock - b.Lock() - defer b.Unlock() + f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) + resp, err := f(req, data) + if err != nil { + return nil, err + } - // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) + // Grab the read lock + b.Lock() + defer b.Unlock() + + // Get our connection + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) + } + + // Make sure we increase the VALID UNTIL endpoint for this user. + if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { + err := db.RenewUser(role.Statements, username, expireTime) + if err != nil { + return nil, err + } + } + + return resp, nil + } +} + +func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + var resp *logical.Response + + roleNameRaw, ok := req.Secret.InternalData["role"] + if !ok { + return nil, fmt.Errorf("no role name was provided") + } + + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + } + + /* TODO: think about how to handle this case. + if !ok { + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + if resp == nil { + resp = &logical.Response{} + } + resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) + } else { + revocationSQL = role.RevocationStatement + } + }*/ + + // Grab the read lock + b.Lock() + defer b.Unlock() + + // Get our connection + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) + } + + err = db.RevokeUser(role.Statements, username) + if err != nil { + return nil, err + } + + return resp, nil } - - err = db.RevokeUser(role.Statements, username) - if err != nil { - return nil, err - } - - return resp, nil } diff --git a/logical/system_view.go b/logical/system_view.go index b69f27090..b6ab14b1f 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -88,11 +88,11 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { } func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { - return "", errors.New("ResponseWrapData is not implimented in StaticSystemView") + return "", errors.New("ResponseWrapData is not implemented in StaticSystemView") } func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { - return nil, errors.New("LookupPlugin is not implimented in StaticSystemView") + return nil, errors.New("LookupPlugin is not implemented in StaticSystemView") } func (d StaticSystemView) MlockDisabled() bool { From faaeb0906502b511d139c97989f16e6f1e5867a9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 09:40:54 -0700 Subject: [PATCH 068/162] Add remaining crud functions to plugin catalog and tests --- helper/builtinplugins/builtin.go | 12 +++ vault/logical_system.go | 31 ++++-- vault/plugin_catalog.go | 47 +++++++++ vault/plugin_catalog_test.go | 166 +++++++++++++++++++++++++++++++ 4 files changed, 249 insertions(+), 7 deletions(-) create mode 100644 vault/plugin_catalog_test.go diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index ba3769c90..55da9a97f 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -22,3 +22,15 @@ func (b *builtinPlugins) Get(name string) (func() error, bool) { f, ok := b.plugins[name] return f, ok } + +func (b *builtinPlugins) Keys() []string { + keys := make([]string, len(b.plugins)) + + i := 0 + for k := range b.plugins { + keys[i] = k + i++ + } + + return keys +} diff --git a/vault/logical_system.go b/vault/logical_system.go index f5dbe2aff..fadae02bf 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -63,7 +63,6 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/reindex", "rotate", "config/auditing/*", - "plugin-catalog", "plugin-catalog/*", }, @@ -694,6 +693,18 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]), HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]), }, + &framework.Path{ + Pattern: "plugin-catalog/$", + + Fields: map[string]*framework.FieldSchema{}, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: b.handlePluginCatalogList, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), + HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + }, &framework.Path{ Pattern: "plugin-catalog/(?P.+)", @@ -750,6 +761,16 @@ func (b *SystemBackend) invalidate(key string) { } } +func (b *SystemBackend) handlePluginCatalogList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + plugins, err := b.Core.pluginCatalog.List() + if err != nil { + return nil, err + } + + resp := logical.ListResponse(plugins) + return resp, nil +} + func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { pluginName := d.Get("name").(string) if pluginName == "" { @@ -801,16 +822,12 @@ func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *frame if pluginName == "" { return logical.ErrorResponse("missing plugin name"), nil } - plugin, err := b.Core.pluginCatalog.Get(pluginName) + err := b.Core.pluginCatalog.Delete(pluginName) if err != nil { return nil, err } - return &logical.Response{ - Data: map[string]interface{}{ - "plugin": plugin, - }, - }, nil + return nil, nil } // handleAuditedHeaderUpdate creates or overwrites a header entry diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 737f0c26b..264a43d44 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "path/filepath" + "sort" "strings" "sync" @@ -39,6 +40,9 @@ func (c *Core) setupPluginCatalog() error { } func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { + c.lock.RLock() + defer c.lock.RUnlock() + // Look for external plugins in the barrier out, err := c.catalogView.Get(name) if err != nil { @@ -68,6 +72,9 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { + c.lock.Lock() + defer c.lock.Unlock() + parts := strings.Split(command, " ") command = parts[0] args := parts[1:] @@ -111,3 +118,43 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { } return nil } + +func (c *PluginCatalog) Delete(name string) error { + c.lock.Lock() + defer c.lock.Unlock() + + return c.catalogView.Delete(name) +} + +func (c *PluginCatalog) List() ([]string, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + keys, err := logical.CollectKeys(c.catalogView) + if err != nil { + return nil, err + } + + builtinKeys := builtinplugins.BuiltinPlugins.Keys() + + mapKeys := make(map[string]bool) + + for _, plugin := range keys { + mapKeys[plugin] = true + } + + for _, plugin := range builtinKeys { + mapKeys[plugin] = true + } + + retList := make([]string, len(mapKeys)) + i := 0 + for k := range mapKeys { + retList[i] = k + i++ + } + // sort for consistent ordering of builtin pluings + sort.Strings(retList) + + return retList, nil +} diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go new file mode 100644 index 000000000..e78e7d963 --- /dev/null +++ b/vault/plugin_catalog_test.go @@ -0,0 +1,166 @@ +package vault + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "sort" + "testing" + + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/pluginutil" +) + +func TestPluginCatalog_CRUD(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + core.pluginCatalog.directory = sym + core.pluginCatalog.vaultCommand = "vault" + core.pluginCatalog.vaultSHA256 = []byte{'1'} + + // Get builtin plugin + p, err := core.pluginCatalog.Get("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + expectedBuiltin := &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Command: "vault", + Args: []string{"plugin-exec", "mysql-database-plugin"}, + Sha256: []byte{'1'}, + Builtin: true, + } + + if !reflect.DeepEqual(p, expectedBuiltin) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) + } + + // Set a plugin, test overwriting a builtin plugin + file, err := ioutil.TempFile(os.TempDir(), "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) + err = core.pluginCatalog.Set("mysql-database-plugin", command, []byte{'1'}) + if err != nil { + t.Fatal(err) + } + + // Get the plugin + p, err = core.pluginCatalog.Get("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + expected := &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Command: filepath.Join(sym, filepath.Base(file.Name())), + Args: []string{"--test"}, + Sha256: []byte{'1'}, + Builtin: false, + } + + if !reflect.DeepEqual(p, expected) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expected) + } + + // Delete the plugin + err = core.pluginCatalog.Delete("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + // Get builtin plugin + p, err = core.pluginCatalog.Get("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if !reflect.DeepEqual(p, expectedBuiltin) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) + } + +} + +func TestPluginCatalog_List(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + core.pluginCatalog.directory = sym + core.pluginCatalog.vaultCommand = "vault" + core.pluginCatalog.vaultSHA256 = []byte{'1'} + + // Get builtin plugins and sort them + builtinKeys := builtinplugins.BuiltinPlugins.Keys() + sort.Strings(builtinKeys) + + // List only builtin plugins + plugins, err := core.pluginCatalog.List() + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if len(plugins) != len(builtinKeys) { + t.Fatalf("unexpected length of plugin list, expected %d, got %d", len(builtinKeys), len(plugins)) + } + + for i, p := range builtinKeys { + if !reflect.DeepEqual(plugins[i], p) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[i], p) + } + } + + // Set a plugin, test overwriting a builtin plugin + file, err := ioutil.TempFile(os.TempDir(), "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) + err = core.pluginCatalog.Set("mysql-database-plugin", command, []byte{'1'}) + if err != nil { + t.Fatal(err) + } + + // Set another plugin + err = core.pluginCatalog.Set("aaaaaaa", command, []byte{'1'}) + if err != nil { + t.Fatal(err) + } + + // List the plugins + plugins, err = core.pluginCatalog.List() + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if len(plugins) != len(builtinKeys)+1 { + t.Fatalf("unexpected length of plugin list, expected %d, got %d", len(builtinKeys)+1, len(plugins)) + } + + // verify the first plugin is the one we just created. + if !reflect.DeepEqual(plugins[0], "aaaaaaa") { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[0], "aaaaaaa") + } + + // verify the builtin pluings are correct + for i, p := range builtinKeys { + if !reflect.DeepEqual(plugins[i+1], p) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[i+1], p) + } + } + +} From c3724c6f17a6f3245da61b1948ab7d517cdea232 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 10:01:36 -0700 Subject: [PATCH 069/162] Add path help and comments for plugin-catalog --- vault/logical_system.go | 25 +++++++++++++++++++++---- vault/plugin_catalog.go | 15 +++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/vault/logical_system.go b/vault/logical_system.go index fadae02bf..2aff5d0e1 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -702,8 +702,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen logical.ListOperation: b.handlePluginCatalogList, }, - HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), - HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, &framework.Path{ Pattern: "plugin-catalog/(?P.+)", @@ -726,8 +726,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen logical.ReadOperation: b.handlePluginCatalogRead, }, - HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), - HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, }, } @@ -2506,4 +2506,21 @@ This path responds to the following HTTP methods. "Lists the headers configured to be audited.", `Returns a list of headers that have been configured to be audited.`, }, + "plugin-catalog": { + `Configures the plugins known to vault`, + ` +This path responds to the following HTTP methods. + GET / + Returns a list of names of configured plugins. + + GET / + Retrieve the metadata for the named plugin. + + PUT / + Add or update plugin. + + DELETE / + Delete the plugin with the given name. + `, + }, } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 264a43d44..b89224780 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -19,6 +19,9 @@ var ( pluginCatalogPrefix = "plugin-catalog/" ) +// PluginCatalog keeps a record of plugins known to vault. External plugins need +// to be registered to the catalog before they can be used in backends. Builtin +// plugins are automatically detected and included in the catalog. type PluginCatalog struct { catalogView *BarrierView directory string @@ -39,6 +42,9 @@ func (c *Core) setupPluginCatalog() error { return nil } +// Get retrieves a plugin with the specified name from the catalog. It first +// looks for external plugins with this name and then looks for builtin plugins. +// It returns a PluginRunner or an error if no plugin was found. func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { c.lock.RLock() defer c.lock.RUnlock() @@ -71,6 +77,8 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { }, nil } +// Set registers a new external plugin with the catalog, or updates an existing +// external plugin. It takes the name, command and SHA256 of the plugin. func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { c.lock.Lock() defer c.lock.Unlock() @@ -119,6 +127,8 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { return nil } +// Delete is used to remove an external plugin from the catalog. Builtin plugins +// can not be deleted. func (c *PluginCatalog) Delete(name string) error { c.lock.Lock() defer c.lock.Unlock() @@ -126,17 +136,22 @@ func (c *PluginCatalog) Delete(name string) error { return c.catalogView.Delete(name) } +// List returns a list of all the known plugin names. If an external and builtin +// plugin share the same name, only one instance of the name will be returned. func (c *PluginCatalog) List() ([]string, error) { c.lock.RLock() defer c.lock.RUnlock() + // Collect keys for external plugins in the barrier. keys, err := logical.CollectKeys(c.catalogView) if err != nil { return nil, err } + // Get the keys for builtin plugins builtinKeys := builtinplugins.BuiltinPlugins.Keys() + // Use a map to unique the two lists mapKeys := make(map[string]bool) for _, plugin := range keys { From 433004f75e7b38c814ca8aaacb41095f15d87479 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 10:39:18 -0700 Subject: [PATCH 070/162] Add test for logical_system plugin-catalog handling --- vault/logical_system_test.go | 93 ++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 15f60a50c..3c808677f 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -2,6 +2,11 @@ package vault import ( "crypto/sha256" + "encoding/hex" + "fmt" + "io/ioutil" + "os" + "path/filepath" "reflect" "strings" "testing" @@ -9,6 +14,8 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" ) @@ -1076,3 +1083,89 @@ func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) { } return c, b, root } + +func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { + c, b, _ := testCoreSystemBackend(t) + // Bootstrap the pluginCatalog + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + c.pluginCatalog.directory = sym + c.pluginCatalog.vaultCommand = "vault" + c.pluginCatalog.vaultSHA256 = []byte{'1'} + + req := logical.TestRequest(t, logical.ListOperation, "plugin-catalog/") + resp, err := b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(resp.Data["keys"].([]string)) != len(builtinplugins.BuiltinPlugins.Keys()) { + t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.BuiltinPlugins.Keys())) + } + + req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/mysql-database-plugin") + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + expectedBuiltin := &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Command: "vault", + Args: []string{"plugin-exec", "mysql-database-plugin"}, + Sha256: []byte{'1'}, + Builtin: true, + } + + if !reflect.DeepEqual(resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) + } + + // Set a plugin + file, err := ioutil.TempFile(os.TempDir(), "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) + req = logical.TestRequest(t, logical.UpdateOperation, "plugin-catalog/test-plugin") + req.Data["sha_256"] = hex.EncodeToString([]byte{'1'}) + req.Data["command"] = command + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + expected := &pluginutil.PluginRunner{ + Name: "test-plugin", + Command: filepath.Join(sym, filepath.Base(file.Name())), + Args: []string{"--test"}, + Sha256: []byte{'1'}, + Builtin: false, + } + if !reflect.DeepEqual(resp.Data["plugin"].(*pluginutil.PluginRunner), expected) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expected) + } + + // Delete plugin + req = logical.TestRequest(t, logical.DeleteOperation, "plugin-catalog/test-plugin") + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + resp, err = b.HandleRequest(req) + if err == nil { + t.Fatalf("expected error, plugin not deleted correctly") + } +} From 3cd5dd18395da125a4f2f118ddf46138dba14dcf Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 14:22:52 -0700 Subject: [PATCH 071/162] Fix RootPaths test --- vault/logical_system_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 3c808677f..c608b3b86 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -32,6 +32,7 @@ func TestSystemBackend_RootPaths(t *testing.T) { "replication/reindex", "rotate", "config/auditing/*", + "plugin-catalog/*", } b := testSystemBackend(t) From 5fac259ae651ed7a0569ccaffde8313e678122ee Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 14:23:15 -0700 Subject: [PATCH 072/162] vendor go-plugin --- vendor/github.com/hashicorp/go-plugin/LICENSE | 353 ++++++++++ .../github.com/hashicorp/go-plugin/README.md | 161 +++++ .../github.com/hashicorp/go-plugin/client.go | 666 ++++++++++++++++++ .../hashicorp/go-plugin/discover.go | 28 + .../github.com/hashicorp/go-plugin/error.go | 24 + .../hashicorp/go-plugin/mux_broker.go | 204 ++++++ .../github.com/hashicorp/go-plugin/plugin.go | 25 + .../github.com/hashicorp/go-plugin/process.go | 24 + .../hashicorp/go-plugin/process_posix.go | 19 + .../hashicorp/go-plugin/process_windows.go | 29 + .../hashicorp/go-plugin/rpc_client.go | 123 ++++ .../hashicorp/go-plugin/rpc_server.go | 185 +++++ .../github.com/hashicorp/go-plugin/server.go | 235 ++++++ .../hashicorp/go-plugin/server_mux.go | 31 + .../github.com/hashicorp/go-plugin/stream.go | 18 + .../github.com/hashicorp/go-plugin/testing.go | 76 ++ vendor/vendor.json | 6 + 17 files changed, 2207 insertions(+) create mode 100644 vendor/github.com/hashicorp/go-plugin/LICENSE create mode 100644 vendor/github.com/hashicorp/go-plugin/README.md create mode 100644 vendor/github.com/hashicorp/go-plugin/client.go create mode 100644 vendor/github.com/hashicorp/go-plugin/discover.go create mode 100644 vendor/github.com/hashicorp/go-plugin/error.go create mode 100644 vendor/github.com/hashicorp/go-plugin/mux_broker.go create mode 100644 vendor/github.com/hashicorp/go-plugin/plugin.go create mode 100644 vendor/github.com/hashicorp/go-plugin/process.go create mode 100644 vendor/github.com/hashicorp/go-plugin/process_posix.go create mode 100644 vendor/github.com/hashicorp/go-plugin/process_windows.go create mode 100644 vendor/github.com/hashicorp/go-plugin/rpc_client.go create mode 100644 vendor/github.com/hashicorp/go-plugin/rpc_server.go create mode 100644 vendor/github.com/hashicorp/go-plugin/server.go create mode 100644 vendor/github.com/hashicorp/go-plugin/server_mux.go create mode 100644 vendor/github.com/hashicorp/go-plugin/stream.go create mode 100644 vendor/github.com/hashicorp/go-plugin/testing.go diff --git a/vendor/github.com/hashicorp/go-plugin/LICENSE b/vendor/github.com/hashicorp/go-plugin/LICENSE new file mode 100644 index 000000000..82b4de97c --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/LICENSE @@ -0,0 +1,353 @@ +Mozilla Public License, version 2.0 + +1. Definitions + +1.1. “Contributor” + + means each individual or legal entity that creates, contributes to the + creation of, or owns Covered Software. + +1.2. “Contributor Version” + + means the combination of the Contributions of others (if any) used by a + Contributor and that particular Contributor’s Contribution. + +1.3. “Contribution” + + means Covered Software of a particular Contributor. + +1.4. “Covered Software” + + means Source Code Form to which the initial Contributor has attached the + notice in Exhibit A, the Executable Form of such Source Code Form, and + Modifications of such Source Code Form, in each case including portions + thereof. + +1.5. “Incompatible With Secondary Licenses” + means + + a. that the initial Contributor has attached the notice described in + Exhibit B to the Covered Software; or + + b. that the Covered Software was made available under the terms of version + 1.1 or earlier of the License, but not also under the terms of a + Secondary License. + +1.6. “Executable Form” + + means any form of the work other than Source Code Form. + +1.7. “Larger Work” + + means a work that combines Covered Software with other material, in a separate + file or files, that is not Covered Software. + +1.8. “License” + + means this document. + +1.9. “Licensable” + + means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently, any and all of the rights conveyed by + this License. + +1.10. “Modifications” + + means any of the following: + + a. any file in Source Code Form that results from an addition to, deletion + from, or modification of the contents of Covered Software; or + + b. any new file in Source Code Form that contains any Covered Software. + +1.11. “Patent Claims” of a Contributor + + means any patent claim(s), including without limitation, method, process, + and apparatus claims, in any patent Licensable by such Contributor that + would be infringed, but for the grant of the License, by the making, + using, selling, offering for sale, having made, import, or transfer of + either its Contributions or its Contributor Version. + +1.12. “Secondary License” + + means either the GNU General Public License, Version 2.0, the GNU Lesser + General Public License, Version 2.1, the GNU Affero General Public + License, Version 3.0, or any later versions of those licenses. + +1.13. “Source Code Form” + + means the form of the work preferred for making modifications. + +1.14. “You” (or “Your”) + + means an individual or a legal entity exercising rights under this + License. For legal entities, “You” includes any entity that controls, is + controlled by, or is under common control with You. For purposes of this + definition, “control” means (a) the power, direct or indirect, to cause + the direction or management of such entity, whether by contract or + otherwise, or (b) ownership of more than fifty percent (50%) of the + outstanding shares or beneficial ownership of such entity. + + +2. License Grants and Conditions + +2.1. Grants + + Each Contributor hereby grants You a world-wide, royalty-free, + non-exclusive license: + + a. under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or as + part of a Larger Work; and + + b. under Patent Claims of such Contributor to make, use, sell, offer for + sale, have made, import, and otherwise transfer either its Contributions + or its Contributor Version. + +2.2. Effective Date + + The licenses granted in Section 2.1 with respect to any Contribution become + effective for each Contribution on the date the Contributor first distributes + such Contribution. + +2.3. Limitations on Grant Scope + + The licenses granted in this Section 2 are the only rights granted under this + License. No additional rights or licenses will be implied from the distribution + or licensing of Covered Software under this License. Notwithstanding Section + 2.1(b) above, no patent license is granted by a Contributor: + + a. for any code that a Contributor has removed from Covered Software; or + + b. for infringements caused by: (i) Your and any other third party’s + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + + c. under Patent Claims infringed by Covered Software in the absence of its + Contributions. + + This License does not grant any rights in the trademarks, service marks, or + logos of any Contributor (except as may be necessary to comply with the + notice requirements in Section 3.4). + +2.4. Subsequent Licenses + + No Contributor makes additional grants as a result of Your choice to + distribute the Covered Software under a subsequent version of this License + (see Section 10.2) or under the terms of a Secondary License (if permitted + under the terms of Section 3.3). + +2.5. Representation + + Each Contributor represents that the Contributor believes its Contributions + are its original creation(s) or it has sufficient rights to grant the + rights to its Contributions conveyed by this License. + +2.6. Fair Use + + This License is not intended to limit any rights You have under applicable + copyright doctrines of fair use, fair dealing, or other equivalents. + +2.7. Conditions + + Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in + Section 2.1. + + +3. Responsibilities + +3.1. Distribution of Source Form + + All distribution of Covered Software in Source Code Form, including any + Modifications that You create or to which You contribute, must be under the + terms of this License. You must inform recipients that the Source Code Form + of the Covered Software is governed by the terms of this License, and how + they can obtain a copy of this License. You may not attempt to alter or + restrict the recipients’ rights in the Source Code Form. + +3.2. Distribution of Executable Form + + If You distribute Covered Software in Executable Form then: + + a. such Covered Software must also be made available in Source Code Form, + as described in Section 3.1, and You must inform recipients of the + Executable Form how they can obtain a copy of such Source Code Form by + reasonable means in a timely manner, at a charge no more than the cost + of distribution to the recipient; and + + b. You may distribute such Executable Form under the terms of this License, + or sublicense it under different terms, provided that the license for + the Executable Form does not attempt to limit or alter the recipients’ + rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + + You may create and distribute a Larger Work under terms of Your choice, + provided that You also comply with the requirements of this License for the + Covered Software. If the Larger Work is a combination of Covered Software + with a work governed by one or more Secondary Licenses, and the Covered + Software is not Incompatible With Secondary Licenses, this License permits + You to additionally distribute such Covered Software under the terms of + such Secondary License(s), so that the recipient of the Larger Work may, at + their option, further distribute the Covered Software under the terms of + either this License or such Secondary License(s). + +3.4. Notices + + You may not remove or alter the substance of any license notices (including + copyright notices, patent notices, disclaimers of warranty, or limitations + of liability) contained within the Source Code Form of the Covered + Software, except that You may alter any license notices to the extent + required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + + You may choose to offer, and to charge a fee for, warranty, support, + indemnity or liability obligations to one or more recipients of Covered + Software. However, You may do so only on Your own behalf, and not on behalf + of any Contributor. You must make it absolutely clear that any such + warranty, support, indemnity, or liability obligation is offered by You + alone, and You hereby agree to indemnify every Contributor for any + liability incurred by such Contributor as a result of warranty, support, + indemnity or liability terms You offer. You may include additional + disclaimers of warranty and limitations of liability specific to any + jurisdiction. + +4. Inability to Comply Due to Statute or Regulation + + If it is impossible for You to comply with any of the terms of this License + with respect to some or all of the Covered Software due to statute, judicial + order, or regulation then You must: (a) comply with the terms of this License + to the maximum extent possible; and (b) describe the limitations and the code + they affect. Such description must be placed in a text file included with all + distributions of the Covered Software under this License. Except to the + extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to + understand it. + +5. Termination + +5.1. The rights granted under this License will terminate automatically if You + fail to comply with any of its terms. However, if You become compliant, + then the rights granted under this License from a particular Contributor + are reinstated (a) provisionally, unless and until such Contributor + explicitly and finally terminates Your grants, and (b) on an ongoing basis, + if such Contributor fails to notify You of the non-compliance by some + reasonable means prior to 60 days after You have come back into compliance. + Moreover, Your grants from a particular Contributor are reinstated on an + ongoing basis if such Contributor notifies You of the non-compliance by + some reasonable means, this is the first time You have received notice of + non-compliance with this License from such Contributor, and You become + compliant prior to 30 days after Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent + infringement claim (excluding declaratory judgment actions, counter-claims, + and cross-claims) alleging that a Contributor Version directly or + indirectly infringes any patent, then the rights granted to You by any and + all Contributors for the Covered Software under Section 2.1 of this License + shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user + license agreements (excluding distributors and resellers) which have been + validly granted by You or Your distributors under this License prior to + termination shall survive termination. + +6. Disclaimer of Warranty + + Covered Software is provided under this License on an “as is” basis, without + warranty of any kind, either expressed, implied, or statutory, including, + without limitation, warranties that the Covered Software is free of defects, + merchantable, fit for a particular purpose or non-infringing. The entire + risk as to the quality and performance of the Covered Software is with You. + Should any Covered Software prove defective in any respect, You (not any + Contributor) assume the cost of any necessary servicing, repair, or + correction. This disclaimer of warranty constitutes an essential part of this + License. No use of any Covered Software is authorized under this License + except under this disclaimer. + +7. Limitation of Liability + + Under no circumstances and under no legal theory, whether tort (including + negligence), contract, or otherwise, shall any Contributor, or anyone who + distributes Covered Software as permitted above, be liable to You for any + direct, indirect, special, incidental, or consequential damages of any + character including, without limitation, damages for lost profits, loss of + goodwill, work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses, even if such party shall have been + informed of the possibility of such damages. This limitation of liability + shall not apply to liability for death or personal injury resulting from such + party’s negligence to the extent applicable law prohibits such limitation. + Some jurisdictions do not allow the exclusion or limitation of incidental or + consequential damages, so this exclusion and limitation may not apply to You. + +8. Litigation + + Any litigation relating to this License may be brought only in the courts of + a jurisdiction where the defendant maintains its principal place of business + and such litigation shall be governed by laws of that jurisdiction, without + reference to its conflict-of-law provisions. Nothing in this Section shall + prevent a party’s ability to bring cross-claims or counter-claims. + +9. Miscellaneous + + This License represents the complete agreement concerning the subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. Any law or regulation which provides that the language of a + contract shall be construed against the drafter shall not be used to construe + this License against a Contributor. + + +10. Versions of the License + +10.1. New Versions + + Mozilla Foundation is the license steward. Except as provided in Section + 10.3, no one other than the license steward has the right to modify or + publish new versions of this License. Each version will be given a + distinguishing version number. + +10.2. Effect of New Versions + + You may distribute the Covered Software under the terms of the version of + the License under which You originally received the Covered Software, or + under the terms of any subsequent version published by the license + steward. + +10.3. Modified Versions + + If you create software not governed by this License, and you want to + create a new license for such software, you may create and use a modified + version of this License if you rename the license and remove any + references to the name of the license steward (except to note that such + modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses + If You choose to distribute Source Code Form that is Incompatible With + Secondary Licenses under the terms of this version of the License, the + notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice + + This Source Code Form is subject to the + terms of the Mozilla Public License, v. + 2.0. If a copy of the MPL was not + distributed with this file, You can + obtain one at + http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular file, then +You may include the notice in a location (such as a LICENSE file in a relevant +directory) where a recipient would be likely to look for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - “Incompatible With Secondary Licenses” Notice + + This Source Code Form is “Incompatible + With Secondary Licenses”, as defined by + the Mozilla Public License, v. 2.0. diff --git a/vendor/github.com/hashicorp/go-plugin/README.md b/vendor/github.com/hashicorp/go-plugin/README.md new file mode 100644 index 000000000..2058cfb68 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/README.md @@ -0,0 +1,161 @@ +# Go Plugin System over RPC + +`go-plugin` is a Go (golang) plugin system over RPC. It is the plugin system +that has been in use by HashiCorp tooling for over 3 years. While initially +created for [Packer](https://www.packer.io), it has since been used by +[Terraform](https://www.terraform.io) and [Otto](https://www.ottoproject.io), +with plans to also use it for [Nomad](https://www.nomadproject.io) and +[Vault](https://www.vaultproject.io). + +While the plugin system is over RPC, it is currently only designed to work +over a local [reliable] network. Plugins over a real network are not supported +and will lead to unexpected behavior. + +This plugin system has been used on millions of machines across many different +projects and has proven to be battle hardened and ready for production use. + +## Features + +The HashiCorp plugin system supports a number of features: + +**Plugins are Go interface implementations.** This makes writing and consuming +plugins feel very natural. To a plugin author: you just implement an +interface as if it were going to run in the same process. For a plugin user: +you just use and call functions on an interface as if it were in the same +process. This plugin system handles the communication in between. + +**Complex arguments and return values are supported.** This library +provides APIs for handling complex arguments and return values such +as interfaces, `io.Reader/Writer`, etc. We do this by giving you a library +(`MuxBroker`) for creating new connections between the client/server to +serve additional interfaces or transfer raw data. + +**Bidirectional communication.** Because the plugin system supports +complex arguments, the host process can send it interface implementations +and the plugin can call back into the host process. + +**Built-in Logging.** Any plugins that use the `log` standard library +will have log data automatically sent to the host process. The host +process will mirror this output prefixed with the path to the plugin +binary. This makes debugging with plugins simple. + +**Protocol Versioning.** A very basic "protocol version" is supported that +can be incremented to invalidate any previous plugins. This is useful when +interface signatures are changing, protocol level changes are necessary, +etc. When a protocol version is incompatible, a human friendly error +message is shown to the end user. + +**Stdout/Stderr Syncing.** While plugins are subprocesses, they can continue +to use stdout/stderr as usual and the output will get mirrored back to +the host process. The host process can control what `io.Writer` these +streams go to to prevent this from happening. + +**TTY Preservation.** Plugin subprocesses are connected to the identical +stdin file descriptor as the host process, allowing software that requires +a TTY to work. For example, a plugin can execute `ssh` and even though there +are multiple subprocesses and RPC happening, it will look and act perfectly +to the end user. + +**Host upgrade while a plugin is running.** Plugins can be "reattached" +so that the host process can be upgraded while the plugin is still running. +This requires the host/plugin to know this is possible and daemonize +properly. `NewClient` takes a `ReattachConfig` to determine if and how to +reattach. + +## Architecture + +The HashiCorp plugin system works by launching subprocesses and communicating +over RPC (using standard `net/rpc`). A single connection is made between +any plugin and the host process, and we use a +[connection multiplexing](https://github.com/hashicorp/yamux) +library to multiplex any other connections on top. + +This architecture has a number of benefits: + + * Plugins can't crash your host process: A panic in a plugin doesn't + panic the plugin user. + + * Plugins are very easy to write: just write a Go application and `go build`. + Theoretically you could also use another language as long as it can + communicate the Go `net/rpc` protocol but this hasn't yet been tried. + + * Plugins are very easy to install: just put the binary in a location where + the host will find it (depends on the host but this library also provides + helpers), and the plugin host handles the rest. + + * Plugins can be relatively secure: The plugin only has access to the + interfaces and args given to it, not to the entire memory space of the + process. More security features are planned (see the coming soon section + below). + +## Usage + +To use the plugin system, you must take the following steps. These are +high-level steps that must be done. Examples are available in the +`examples/` directory. + + 1. Choose the interface(s) you want to expose for plugins. + + 2. For each interface, implement an implementation of that interface + that communicates over an `*rpc.Client` (from the standard `net/rpc` + package) for every function call. Likewise, implement the RPC server + struct this communicates to which is then communicating to a real, + concrete implementation. + + 3. Create a `Plugin` implementation that knows how to create the RPC + client/server for a given plugin type. + + 4. Plugin authors call `plugin.Serve` to serve a plugin from the + `main` function. + + 5. Plugin users use `plugin.Client` to launch a subprocess and request + an interface implementation over RPC. + +That's it! In practice, step 2 is the most tedious and time consuming step. +Even so, it isn't very difficult and you can see examples in the `examples/` +directory as well as throughout our various open source projects. + +For complete API documentation, see [GoDoc](https://godoc.org/github.com/hashicorp/go-plugin). + +## Roadmap + +Our plugin system is constantly evolving. As we use the plugin system for +new projects or for new features in existing projects, we constantly find +improvements we can make. + +At this point in time, the roadmap for the plugin system is: + +**Cryptographically Secure Plugins.** We'll implement signing plugins +and loading signed plugins in order to allow Vault to make use of multi-process +in a secure way. + +**Semantic Versioning.** Plugins will be able to implement a semantic version. +This plugin system will give host processes a system for constraining +versions. This is in addition to the protocol versioning already present +which is more for larger underlying changes. + +**Plugin fetching.** We will integrate with [go-getter](https://github.com/hashicorp/go-getter) +to support automatic download + install of plugins. Paired with cryptographically +secure plugins (above), we can make this a safe operation for an amazing +user experience. + +## What About Shared Libraries? + +When we started using plugins (late 2012, early 2013), plugins over RPC +were the only option since Go didn't support dynamic library loading. Today, +Go still doesn't support dynamic library loading, but they do intend to. +Since 2012, our plugin system has stabilized from millions of users using it, +and has many benefits we've come to value greatly. + +For example, we intend to use this plugin system in +[Vault](https://www.vaultproject.io), and dynamic library loading will +simply never be acceptable in Vault for security reasons. That is an extreme +example, but we believe our library system has more upsides than downsides +over dynamic library loading and since we've had it built and tested for years, +we'll likely continue to use it. + +Shared libraries have one major advantage over our system which is much +higher performance. In real world scenarios across our various tools, +we've never required any more performance out of our plugin system and it +has seen very high throughput, so this isn't a concern for us at the moment. + diff --git a/vendor/github.com/hashicorp/go-plugin/client.go b/vendor/github.com/hashicorp/go-plugin/client.go new file mode 100644 index 000000000..b69d41b28 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/client.go @@ -0,0 +1,666 @@ +package plugin + +import ( + "bufio" + "crypto/subtle" + "crypto/tls" + "errors" + "fmt" + "hash" + "io" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + "unicode" +) + +// If this is 1, then we've called CleanupClients. This can be used +// by plugin RPC implementations to change error behavior since you +// can expected network connection errors at this point. This should be +// read by using sync/atomic. +var Killed uint32 = 0 + +// This is a slice of the "managed" clients which are cleaned up when +// calling Cleanup +var managedClients = make([]*Client, 0, 5) +var managedClientsLock sync.Mutex + +// Error types +var ( + // ErrProcessNotFound is returned when a client is instantiated to + // reattach to an existing process and it isn't found. + ErrProcessNotFound = errors.New("Reattachment process not found") + + // ErrChecksumsDoNotMatch is returned when binary's checksum doesn't match + // the one provided in the SecureConfig. + ErrChecksumsDoNotMatch = errors.New("checksums did not match") + + // ErrSecureNoChecksum is returned when an empty checksum is provided to the + // SecureConfig. + ErrSecureConfigNoChecksum = errors.New("no checksum provided") + + // ErrSecureNoHash is returned when a nil Hash object is provided to the + // SecureConfig. + ErrSecureConfigNoHash = errors.New("no hash implementation provided") + + // ErrSecureConfigAndReattach is returned when both Reattach and + // SecureConfig are set. + ErrSecureConfigAndReattach = errors.New("only one of Reattach or SecureConfig can be set") +) + +// Client handles the lifecycle of a plugin application. It launches +// plugins, connects to them, dispenses interface implementations, and handles +// killing the process. +// +// Plugin hosts should use one Client for each plugin executable. To +// dispense a plugin type, use the `Client.Client` function, and then +// cal `Dispense`. This awkward API is mostly historical but is used to split +// the client that deals with subprocess management and the client that +// does RPC management. +// +// See NewClient and ClientConfig for using a Client. +type Client struct { + config *ClientConfig + exited bool + doneLogging chan struct{} + l sync.Mutex + address net.Addr + process *os.Process + client *RPCClient +} + +// ClientConfig is the configuration used to initialize a new +// plugin client. After being used to initialize a plugin client, +// that configuration must not be modified again. +type ClientConfig struct { + // HandshakeConfig is the configuration that must match servers. + HandshakeConfig + + // Plugins are the plugins that can be consumed. + Plugins map[string]Plugin + + // One of the following must be set, but not both. + // + // Cmd is the unstarted subprocess for starting the plugin. If this is + // set, then the Client starts the plugin process on its own and connects + // to it. + // + // Reattach is configuration for reattaching to an existing plugin process + // that is already running. This isn't common. + Cmd *exec.Cmd + Reattach *ReattachConfig + + // SecureConfig is configuration for verifying the integrity of the + // executable. It can not be used with Reattach. + SecureConfig *SecureConfig + + // TLSConfig is used to enable TLS on the RPC client. + TLSConfig *tls.Config + + // Managed represents if the client should be managed by the + // plugin package or not. If true, then by calling CleanupClients, + // it will automatically be cleaned up. Otherwise, the client + // user is fully responsible for making sure to Kill all plugin + // clients. By default the client is _not_ managed. + Managed bool + + // The minimum and maximum port to use for communicating with + // the subprocess. If not set, this defaults to 10,000 and 25,000 + // respectively. + MinPort, MaxPort uint + + // StartTimeout is the timeout to wait for the plugin to say it + // has started successfully. + StartTimeout time.Duration + + // If non-nil, then the stderr of the client will be written to here + // (as well as the log). This is the original os.Stderr of the subprocess. + // This isn't the output of synced stderr. + Stderr io.Writer + + // SyncStdout, SyncStderr can be set to override the + // respective os.Std* values in the plugin. Care should be taken to + // avoid races here. If these are nil, then this will automatically be + // hooked up to os.Stdin, Stdout, and Stderr, respectively. + // + // If the default values (nil) are used, then this package will not + // sync any of these streams. + SyncStdout io.Writer + SyncStderr io.Writer +} + +// ReattachConfig is used to configure a client to reattach to an +// already-running plugin process. You can retrieve this information by +// calling ReattachConfig on Client. +type ReattachConfig struct { + Addr net.Addr + Pid int +} + +// SecureConfig is used to configure a client to verify the integrity of an +// executable before running. It does this by verifying the checksum is +// expected. Hash is used to specify the hashing method to use when checksumming +// the file. The configuration is verified by the client by calling the +// SecureConfig.Check() function. +// +// The host process should ensure the checksum was provided by a trusted and +// authoritative source. The binary should be installed in such a way that it +// can not be modified by an unauthorized user between the time of this check +// and the time of execution. +type SecureConfig struct { + Checksum []byte + Hash hash.Hash +} + +// Check takes the filepath to an executable and returns true if the checksum of +// the file matches the checksum provided in the SecureConfig. +func (s *SecureConfig) Check(filePath string) (bool, error) { + if len(s.Checksum) == 0 { + return false, ErrSecureConfigNoChecksum + } + + if s.Hash == nil { + return false, ErrSecureConfigNoHash + } + + file, err := os.Open(filePath) + if err != nil { + return false, err + } + defer file.Close() + + _, err = io.Copy(s.Hash, file) + if err != nil { + return false, err + } + + sum := s.Hash.Sum(nil) + + return subtle.ConstantTimeCompare(sum, s.Checksum) == 1, nil +} + +// This makes sure all the managed subprocesses are killed and properly +// logged. This should be called before the parent process running the +// plugins exits. +// +// This must only be called _once_. +func CleanupClients() { + // Set the killed to true so that we don't get unexpected panics + atomic.StoreUint32(&Killed, 1) + + // Kill all the managed clients in parallel and use a WaitGroup + // to wait for them all to finish up. + var wg sync.WaitGroup + managedClientsLock.Lock() + for _, client := range managedClients { + wg.Add(1) + + go func(client *Client) { + client.Kill() + wg.Done() + }(client) + } + managedClientsLock.Unlock() + + log.Println("[DEBUG] plugin: waiting for all plugin processes to complete...") + wg.Wait() +} + +// Creates a new plugin client which manages the lifecycle of an external +// plugin and gets the address for the RPC connection. +// +// The client must be cleaned up at some point by calling Kill(). If +// the client is a managed client (created with NewManagedClient) you +// can just call CleanupClients at the end of your program and they will +// be properly cleaned. +func NewClient(config *ClientConfig) (c *Client) { + if config.MinPort == 0 && config.MaxPort == 0 { + config.MinPort = 10000 + config.MaxPort = 25000 + } + + if config.StartTimeout == 0 { + config.StartTimeout = 1 * time.Minute + } + + if config.Stderr == nil { + config.Stderr = ioutil.Discard + } + + if config.SyncStdout == nil { + config.SyncStdout = ioutil.Discard + } + if config.SyncStderr == nil { + config.SyncStderr = ioutil.Discard + } + + c = &Client{config: config} + if config.Managed { + managedClientsLock.Lock() + managedClients = append(managedClients, c) + managedClientsLock.Unlock() + } + + return +} + +// Client returns an RPC client for the plugin. +// +// Subsequent calls to this will return the same RPC client. +func (c *Client) Client() (*RPCClient, error) { + addr, err := c.Start() + if err != nil { + return nil, err + } + + c.l.Lock() + defer c.l.Unlock() + + if c.client != nil { + return c.client, nil + } + + // Connect to the client + conn, err := net.Dial(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + if tcpConn, ok := conn.(*net.TCPConn); ok { + // Make sure to set keep alive so that the connection doesn't die + tcpConn.SetKeepAlive(true) + } + + if c.config.TLSConfig != nil { + conn = tls.Client(conn, c.config.TLSConfig) + } + + // Create the actual RPC client + c.client, err = NewRPCClient(conn, c.config.Plugins) + if err != nil { + conn.Close() + return nil, err + } + + // Begin the stream syncing so that stdin, out, err work properly + err = c.client.SyncStreams( + c.config.SyncStdout, + c.config.SyncStderr) + if err != nil { + c.client.Close() + c.client = nil + return nil, err + } + + return c.client, nil +} + +// Tells whether or not the underlying process has exited. +func (c *Client) Exited() bool { + c.l.Lock() + defer c.l.Unlock() + return c.exited +} + +// End the executing subprocess (if it is running) and perform any cleanup +// tasks necessary such as capturing any remaining logs and so on. +// +// This method blocks until the process successfully exits. +// +// This method can safely be called multiple times. +func (c *Client) Kill() { + // Grab a lock to read some private fields. + c.l.Lock() + process := c.process + addr := c.address + doneCh := c.doneLogging + c.l.Unlock() + + // If there is no process, we never started anything. Nothing to kill. + if process == nil { + return + } + + // We need to check for address here. It is possible that the plugin + // started (process != nil) but has no address (addr == nil) if the + // plugin failed at startup. If we do have an address, we need to close + // the plugin net connections. + graceful := false + if addr != nil { + // Close the client to cleanly exit the process. + client, err := c.Client() + if err == nil { + err = client.Close() + + // If there is no error, then we attempt to wait for a graceful + // exit. If there was an error, we assume that graceful cleanup + // won't happen and just force kill. + graceful = err == nil + if err != nil { + // If there was an error just log it. We're going to force + // kill in a moment anyways. + log.Printf( + "[WARN] plugin: error closing client during Kill: %s", err) + } + } + } + + // If we're attempting a graceful exit, then we wait for a short period + // of time to allow that to happen. To wait for this we just wait on the + // doneCh which would be closed if the process exits. + if graceful { + select { + case <-doneCh: + return + case <-time.After(250 * time.Millisecond): + } + } + + // If graceful exiting failed, just kill it + process.Kill() + + // Wait for the client to finish logging so we have a complete log + <-doneCh +} + +// Starts the underlying subprocess, communicating with it to negotiate +// a port for RPC connections, and returning the address to connect via RPC. +// +// This method is safe to call multiple times. Subsequent calls have no effect. +// Once a client has been started once, it cannot be started again, even if +// it was killed. +func (c *Client) Start() (addr net.Addr, err error) { + c.l.Lock() + defer c.l.Unlock() + + if c.address != nil { + return c.address, nil + } + + // If one of cmd or reattach isn't set, then it is an error. We wrap + // this in a {} for scoping reasons, and hopeful that the escape + // analysis will pop the stock here. + { + cmdSet := c.config.Cmd != nil + attachSet := c.config.Reattach != nil + secureSet := c.config.SecureConfig != nil + if cmdSet == attachSet { + return nil, fmt.Errorf("Only one of Cmd or Reattach must be set") + } + + if secureSet && attachSet { + return nil, ErrSecureConfigAndReattach + } + } + + // Create the logging channel for when we kill + c.doneLogging = make(chan struct{}) + + if c.config.Reattach != nil { + // Verify the process still exists. If not, then it is an error + p, err := os.FindProcess(c.config.Reattach.Pid) + if err != nil { + return nil, err + } + + // Attempt to connect to the addr since on Unix systems FindProcess + // doesn't actually return an error if it can't find the process. + conn, err := net.Dial( + c.config.Reattach.Addr.Network(), + c.config.Reattach.Addr.String()) + if err != nil { + p.Kill() + return nil, ErrProcessNotFound + } + conn.Close() + + // Goroutine to mark exit status + go func(pid int) { + // Wait for the process to die + pidWait(pid) + + // Log so we can see it + log.Printf("[DEBUG] plugin: reattached plugin process exited\n") + + // Mark it + c.l.Lock() + defer c.l.Unlock() + c.exited = true + + // Close the logging channel since that doesn't work on reattach + close(c.doneLogging) + }(p.Pid) + + // Set the address and process + c.address = c.config.Reattach.Addr + c.process = p + + return c.address, nil + } + + env := []string{ + fmt.Sprintf("%s=%s", c.config.MagicCookieKey, c.config.MagicCookieValue), + fmt.Sprintf("PLUGIN_MIN_PORT=%d", c.config.MinPort), + fmt.Sprintf("PLUGIN_MAX_PORT=%d", c.config.MaxPort), + } + + stdout_r, stdout_w := io.Pipe() + stderr_r, stderr_w := io.Pipe() + + cmd := c.config.Cmd + cmd.Env = append(cmd.Env, os.Environ()...) + cmd.Env = append(cmd.Env, env...) + cmd.Stdin = os.Stdin + cmd.Stderr = stderr_w + cmd.Stdout = stdout_w + + if c.config.SecureConfig != nil { + if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil { + return nil, fmt.Errorf("error verifying checksum: %s", err) + } else if !ok { + return nil, ErrChecksumsDoNotMatch + } + } + + log.Printf("[DEBUG] plugin: starting plugin: %s %#v", cmd.Path, cmd.Args) + err = cmd.Start() + if err != nil { + return + } + + // Set the process + c.process = cmd.Process + + // Make sure the command is properly cleaned up if there is an error + defer func() { + r := recover() + + if err != nil || r != nil { + cmd.Process.Kill() + } + + if r != nil { + panic(r) + } + }() + + // Start goroutine to wait for process to exit + exitCh := make(chan struct{}) + go func() { + // Make sure we close the write end of our stderr/stdout so + // that the readers send EOF properly. + defer stderr_w.Close() + defer stdout_w.Close() + + // Wait for the command to end. + cmd.Wait() + + // Log and make sure to flush the logs write away + log.Printf("[DEBUG] plugin: %s: plugin process exited\n", cmd.Path) + os.Stderr.Sync() + + // Mark that we exited + close(exitCh) + + // Set that we exited, which takes a lock + c.l.Lock() + defer c.l.Unlock() + c.exited = true + }() + + // Start goroutine that logs the stderr + go c.logStderr(stderr_r) + + // Start a goroutine that is going to be reading the lines + // out of stdout + linesCh := make(chan []byte) + go func() { + defer close(linesCh) + + buf := bufio.NewReader(stdout_r) + for { + line, err := buf.ReadBytes('\n') + if line != nil { + linesCh <- line + } + + if err == io.EOF { + return + } + } + }() + + // Make sure after we exit we read the lines from stdout forever + // so they don't block since it is an io.Pipe + defer func() { + go func() { + for _ = range linesCh { + } + }() + }() + + // Some channels for the next step + timeout := time.After(c.config.StartTimeout) + + // Start looking for the address + log.Printf("[DEBUG] plugin: waiting for RPC address for: %s", cmd.Path) + select { + case <-timeout: + err = errors.New("timeout while waiting for plugin to start") + case <-exitCh: + err = errors.New("plugin exited before we could connect") + case lineBytes := <-linesCh: + // Trim the line and split by "|" in order to get the parts of + // the output. + line := strings.TrimSpace(string(lineBytes)) + parts := strings.SplitN(line, "|", 4) + if len(parts) < 4 { + err = fmt.Errorf( + "Unrecognized remote plugin message: %s\n\n"+ + "This usually means that the plugin is either invalid or simply\n"+ + "needs to be recompiled to support the latest protocol.", line) + return + } + + // Check the core protocol. Wrapped in a {} for scoping. + { + var coreProtocol int64 + coreProtocol, err = strconv.ParseInt(parts[0], 10, 0) + if err != nil { + err = fmt.Errorf("Error parsing core protocol version: %s", err) + return + } + + if int(coreProtocol) != CoreProtocolVersion { + err = fmt.Errorf("Incompatible core API version with plugin. "+ + "Plugin version: %s, Ours: %d\n\n"+ + "To fix this, the plugin usually only needs to be recompiled.\n"+ + "Please report this to the plugin author.", parts[0], CoreProtocolVersion) + return + } + } + + // Parse the protocol version + var protocol int64 + protocol, err = strconv.ParseInt(parts[1], 10, 0) + if err != nil { + err = fmt.Errorf("Error parsing protocol version: %s", err) + return + } + + // Test the API version + if uint(protocol) != c.config.ProtocolVersion { + err = fmt.Errorf("Incompatible API version with plugin. "+ + "Plugin version: %s, Ours: %d", parts[1], c.config.ProtocolVersion) + return + } + + switch parts[2] { + case "tcp": + addr, err = net.ResolveTCPAddr("tcp", parts[3]) + case "unix": + addr, err = net.ResolveUnixAddr("unix", parts[3]) + default: + err = fmt.Errorf("Unknown address type: %s", parts[3]) + } + } + + c.address = addr + return +} + +// ReattachConfig returns the information that must be provided to NewClient +// to reattach to the plugin process that this client started. This is +// useful for plugins that detach from their parent process. +// +// If this returns nil then the process hasn't been started yet. Please +// call Start or Client before calling this. +func (c *Client) ReattachConfig() *ReattachConfig { + c.l.Lock() + defer c.l.Unlock() + + if c.address == nil { + return nil + } + + if c.config.Cmd != nil && c.config.Cmd.Process == nil { + return nil + } + + // If we connected via reattach, just return the information as-is + if c.config.Reattach != nil { + return c.config.Reattach + } + + return &ReattachConfig{ + Addr: c.address, + Pid: c.config.Cmd.Process.Pid, + } +} + +func (c *Client) logStderr(r io.Reader) { + bufR := bufio.NewReader(r) + for { + line, err := bufR.ReadString('\n') + if line != "" { + c.config.Stderr.Write([]byte(line)) + + line = strings.TrimRightFunc(line, unicode.IsSpace) + log.Printf("[DEBUG] plugin: %s: %s", filepath.Base(c.config.Cmd.Path), line) + } + + if err == io.EOF { + break + } + } + + // Flag that we've completed logging for others + close(c.doneLogging) +} diff --git a/vendor/github.com/hashicorp/go-plugin/discover.go b/vendor/github.com/hashicorp/go-plugin/discover.go new file mode 100644 index 000000000..d22c566ed --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/discover.go @@ -0,0 +1,28 @@ +package plugin + +import ( + "path/filepath" +) + +// Discover discovers plugins that are in a given directory. +// +// The directory doesn't need to be absolute. For example, "." will work fine. +// +// This currently assumes any file matching the glob is a plugin. +// In the future this may be smarter about checking that a file is +// executable and so on. +// +// TODO: test +func Discover(glob, dir string) ([]string, error) { + var err error + + // Make the directory absolute if it isn't already + if !filepath.IsAbs(dir) { + dir, err = filepath.Abs(dir) + if err != nil { + return nil, err + } + } + + return filepath.Glob(filepath.Join(dir, glob)) +} diff --git a/vendor/github.com/hashicorp/go-plugin/error.go b/vendor/github.com/hashicorp/go-plugin/error.go new file mode 100644 index 000000000..22a7baa6a --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/error.go @@ -0,0 +1,24 @@ +package plugin + +// This is a type that wraps error types so that they can be messaged +// across RPC channels. Since "error" is an interface, we can't always +// gob-encode the underlying structure. This is a valid error interface +// implementer that we will push across. +type BasicError struct { + Message string +} + +// NewBasicError is used to create a BasicError. +// +// err is allowed to be nil. +func NewBasicError(err error) *BasicError { + if err == nil { + return nil + } + + return &BasicError{err.Error()} +} + +func (e *BasicError) Error() string { + return e.Message +} diff --git a/vendor/github.com/hashicorp/go-plugin/mux_broker.go b/vendor/github.com/hashicorp/go-plugin/mux_broker.go new file mode 100644 index 000000000..01c45ad7c --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/mux_broker.go @@ -0,0 +1,204 @@ +package plugin + +import ( + "encoding/binary" + "fmt" + "log" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/yamux" +) + +// MuxBroker is responsible for brokering multiplexed connections by unique ID. +// +// It is used by plugins to multiplex multiple RPC connections and data +// streams on top of a single connection between the plugin process and the +// host process. +// +// This allows a plugin to request a channel with a specific ID to connect to +// or accept a connection from, and the broker handles the details of +// holding these channels open while they're being negotiated. +// +// The Plugin interface has access to these for both Server and Client. +// The broker can be used by either (optionally) to reserve and connect to +// new multiplexed streams. This is useful for complex args and return values, +// or anything else you might need a data stream for. +type MuxBroker struct { + nextId uint32 + session *yamux.Session + streams map[uint32]*muxBrokerPending + + sync.Mutex +} + +type muxBrokerPending struct { + ch chan net.Conn + doneCh chan struct{} +} + +func newMuxBroker(s *yamux.Session) *MuxBroker { + return &MuxBroker{ + session: s, + streams: make(map[uint32]*muxBrokerPending), + } +} + +// Accept accepts a connection by ID. +// +// This should not be called multiple times with the same ID at one time. +func (m *MuxBroker) Accept(id uint32) (net.Conn, error) { + var c net.Conn + p := m.getStream(id) + select { + case c = <-p.ch: + close(p.doneCh) + case <-time.After(5 * time.Second): + m.Lock() + defer m.Unlock() + delete(m.streams, id) + + return nil, fmt.Errorf("timeout waiting for accept") + } + + // Ack our connection + if err := binary.Write(c, binary.LittleEndian, id); err != nil { + c.Close() + return nil, err + } + + return c, nil +} + +// AcceptAndServe is used to accept a specific stream ID and immediately +// serve an RPC server on that stream ID. This is used to easily serve +// complex arguments. +// +// The served interface is always registered to the "Plugin" name. +func (m *MuxBroker) AcceptAndServe(id uint32, v interface{}) { + conn, err := m.Accept(id) + if err != nil { + log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err) + return + } + + serve(conn, "Plugin", v) +} + +// Close closes the connection and all sub-connections. +func (m *MuxBroker) Close() error { + return m.session.Close() +} + +// Dial opens a connection by ID. +func (m *MuxBroker) Dial(id uint32) (net.Conn, error) { + // Open the stream + stream, err := m.session.OpenStream() + if err != nil { + return nil, err + } + + // Write the stream ID onto the wire. + if err := binary.Write(stream, binary.LittleEndian, id); err != nil { + stream.Close() + return nil, err + } + + // Read the ack that we connected. Then we're off! + var ack uint32 + if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil { + stream.Close() + return nil, err + } + if ack != id { + stream.Close() + return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id) + } + + return stream, nil +} + +// NextId returns a unique ID to use next. +// +// It is possible for very long-running plugin hosts to wrap this value, +// though it would require a very large amount of RPC calls. In practice +// we've never seen it happen. +func (m *MuxBroker) NextId() uint32 { + return atomic.AddUint32(&m.nextId, 1) +} + +// Run starts the brokering and should be executed in a goroutine, since it +// blocks forever, or until the session closes. +// +// Uses of MuxBroker never need to call this. It is called internally by +// the plugin host/client. +func (m *MuxBroker) Run() { + for { + stream, err := m.session.AcceptStream() + if err != nil { + // Once we receive an error, just exit + break + } + + // Read the stream ID from the stream + var id uint32 + if err := binary.Read(stream, binary.LittleEndian, &id); err != nil { + stream.Close() + continue + } + + // Initialize the waiter + p := m.getStream(id) + select { + case p.ch <- stream: + default: + } + + // Wait for a timeout + go m.timeoutWait(id, p) + } +} + +func (m *MuxBroker) getStream(id uint32) *muxBrokerPending { + m.Lock() + defer m.Unlock() + + p, ok := m.streams[id] + if ok { + return p + } + + m.streams[id] = &muxBrokerPending{ + ch: make(chan net.Conn, 1), + doneCh: make(chan struct{}), + } + return m.streams[id] +} + +func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) { + // Wait for the stream to either be picked up and connected, or + // for a timeout. + timeout := false + select { + case <-p.doneCh: + case <-time.After(5 * time.Second): + timeout = true + } + + m.Lock() + defer m.Unlock() + + // Delete the stream so no one else can grab it + delete(m.streams, id) + + // If we timed out, then check if we have a channel in the buffer, + // and if so, close it. + if timeout { + select { + case s := <-p.ch: + s.Close() + } + } +} diff --git a/vendor/github.com/hashicorp/go-plugin/plugin.go b/vendor/github.com/hashicorp/go-plugin/plugin.go new file mode 100644 index 000000000..37c8fd653 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/plugin.go @@ -0,0 +1,25 @@ +// The plugin package exposes functions and helpers for communicating to +// plugins which are implemented as standalone binary applications. +// +// plugin.Client fully manages the lifecycle of executing the application, +// connecting to it, and returning the RPC client for dispensing plugins. +// +// plugin.Serve fully manages listeners to expose an RPC server from a binary +// that plugin.Client can connect to. +package plugin + +import ( + "net/rpc" +) + +// Plugin is the interface that is implemented to serve/connect to an +// inteface implementation. +type Plugin interface { + // Server should return the RPC server compatible struct to serve + // the methods that the Client calls over net/rpc. + Server(*MuxBroker) (interface{}, error) + + // Client returns an interface implementation for the plugin you're + // serving that communicates to the server end of the plugin. + Client(*MuxBroker, *rpc.Client) (interface{}, error) +} diff --git a/vendor/github.com/hashicorp/go-plugin/process.go b/vendor/github.com/hashicorp/go-plugin/process.go new file mode 100644 index 000000000..88c999a58 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/process.go @@ -0,0 +1,24 @@ +package plugin + +import ( + "time" +) + +// pidAlive checks whether a pid is alive. +func pidAlive(pid int) bool { + return _pidAlive(pid) +} + +// pidWait blocks for a process to exit. +func pidWait(pid int) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if !pidAlive(pid) { + break + } + } + + return nil +} diff --git a/vendor/github.com/hashicorp/go-plugin/process_posix.go b/vendor/github.com/hashicorp/go-plugin/process_posix.go new file mode 100644 index 000000000..70ba546bf --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/process_posix.go @@ -0,0 +1,19 @@ +// +build !windows + +package plugin + +import ( + "os" + "syscall" +) + +// _pidAlive tests whether a process is alive or not by sending it Signal 0, +// since Go otherwise has no way to test this. +func _pidAlive(pid int) bool { + proc, err := os.FindProcess(pid) + if err == nil { + err = proc.Signal(syscall.Signal(0)) + } + + return err == nil +} diff --git a/vendor/github.com/hashicorp/go-plugin/process_windows.go b/vendor/github.com/hashicorp/go-plugin/process_windows.go new file mode 100644 index 000000000..9f7b01809 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/process_windows.go @@ -0,0 +1,29 @@ +package plugin + +import ( + "syscall" +) + +const ( + // Weird name but matches the MSDN docs + exit_STILL_ACTIVE = 259 + + processDesiredAccess = syscall.STANDARD_RIGHTS_READ | + syscall.PROCESS_QUERY_INFORMATION | + syscall.SYNCHRONIZE +) + +// _pidAlive tests whether a process is alive or not +func _pidAlive(pid int) bool { + h, err := syscall.OpenProcess(processDesiredAccess, false, uint32(pid)) + if err != nil { + return false + } + + var ec uint32 + if e := syscall.GetExitCodeProcess(h, &ec); e != nil { + return false + } + + return ec == exit_STILL_ACTIVE +} diff --git a/vendor/github.com/hashicorp/go-plugin/rpc_client.go b/vendor/github.com/hashicorp/go-plugin/rpc_client.go new file mode 100644 index 000000000..29f9bf063 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/rpc_client.go @@ -0,0 +1,123 @@ +package plugin + +import ( + "fmt" + "io" + "net" + "net/rpc" + + "github.com/hashicorp/yamux" +) + +// RPCClient connects to an RPCServer over net/rpc to dispense plugin types. +type RPCClient struct { + broker *MuxBroker + control *rpc.Client + plugins map[string]Plugin + + // These are the streams used for the various stdout/err overrides + stdout, stderr net.Conn +} + +// NewRPCClient creates a client from an already-open connection-like value. +// Dial is typically used instead. +func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) { + // Create the yamux client so we can multiplex + mux, err := yamux.Client(conn, nil) + if err != nil { + conn.Close() + return nil, err + } + + // Connect to the control stream. + control, err := mux.Open() + if err != nil { + mux.Close() + return nil, err + } + + // Connect stdout, stderr streams + stdstream := make([]net.Conn, 2) + for i, _ := range stdstream { + stdstream[i], err = mux.Open() + if err != nil { + mux.Close() + return nil, err + } + } + + // Create the broker and start it up + broker := newMuxBroker(mux) + go broker.Run() + + // Build the client using our broker and control channel. + return &RPCClient{ + broker: broker, + control: rpc.NewClient(control), + plugins: plugins, + stdout: stdstream[0], + stderr: stdstream[1], + }, nil +} + +// SyncStreams should be called to enable syncing of stdout, +// stderr with the plugin. +// +// This will return immediately and the syncing will continue to happen +// in the background. You do not need to launch this in a goroutine itself. +// +// This should never be called multiple times. +func (c *RPCClient) SyncStreams(stdout io.Writer, stderr io.Writer) error { + go copyStream("stdout", stdout, c.stdout) + go copyStream("stderr", stderr, c.stderr) + return nil +} + +// Close closes the connection. The client is no longer usable after this +// is called. +func (c *RPCClient) Close() error { + // Call the control channel and ask it to gracefully exit. If this + // errors, then we save it so that we always return an error but we + // want to try to close the other channels anyways. + var empty struct{} + returnErr := c.control.Call("Control.Quit", true, &empty) + + // Close the other streams we have + if err := c.control.Close(); err != nil { + return err + } + if err := c.stdout.Close(); err != nil { + return err + } + if err := c.stderr.Close(); err != nil { + return err + } + if err := c.broker.Close(); err != nil { + return err + } + + // Return back the error we got from Control.Quit. This is very important + // since we MUST return non-nil error if this fails so that Client.Kill + // will properly try a process.Kill. + return returnErr +} + +func (c *RPCClient) Dispense(name string) (interface{}, error) { + p, ok := c.plugins[name] + if !ok { + return nil, fmt.Errorf("unknown plugin type: %s", name) + } + + var id uint32 + if err := c.control.Call( + "Dispenser.Dispense", name, &id); err != nil { + return nil, err + } + + conn, err := c.broker.Dial(id) + if err != nil { + return nil, err + } + + return p.Client(c.broker, rpc.NewClient(conn)) +} diff --git a/vendor/github.com/hashicorp/go-plugin/rpc_server.go b/vendor/github.com/hashicorp/go-plugin/rpc_server.go new file mode 100644 index 000000000..3984dc891 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/rpc_server.go @@ -0,0 +1,185 @@ +package plugin + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "net/rpc" + "sync" + + "github.com/hashicorp/yamux" +) + +// RPCServer listens for network connections and then dispenses interface +// implementations over net/rpc. +// +// After setting the fields below, they shouldn't be read again directly +// from the structure which may be reading/writing them concurrently. +type RPCServer struct { + Plugins map[string]Plugin + + // Stdout, Stderr are what this server will use instead of the + // normal stdin/out/err. This is because due to the multi-process nature + // of our plugin system, we can't use the normal process values so we + // make our own custom one we pipe across. + Stdout io.Reader + Stderr io.Reader + + // DoneCh should be set to a non-nil channel that will be closed + // when the control requests the RPC server to end. + DoneCh chan<- struct{} + + lock sync.Mutex +} + +// Accept accepts connections on a listener and serves requests for +// each incoming connection. Accept blocks; the caller typically invokes +// it in a go statement. +func (s *RPCServer) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Printf("[ERR] plugin: plugin server: %s", err) + return + } + + go s.ServeConn(conn) + } +} + +// ServeConn runs a single connection. +// +// ServeConn blocks, serving the connection until the client hangs up. +func (s *RPCServer) ServeConn(conn io.ReadWriteCloser) { + // First create the yamux server to wrap this connection + mux, err := yamux.Server(conn, nil) + if err != nil { + conn.Close() + log.Printf("[ERR] plugin: error creating yamux server: %s", err) + return + } + + // Accept the control connection + control, err := mux.Accept() + if err != nil { + mux.Close() + if err != io.EOF { + log.Printf("[ERR] plugin: error accepting control connection: %s", err) + } + + return + } + + // Connect the stdstreams (in, out, err) + stdstream := make([]net.Conn, 2) + for i, _ := range stdstream { + stdstream[i], err = mux.Accept() + if err != nil { + mux.Close() + log.Printf("[ERR] plugin: accepting stream %d: %s", i, err) + return + } + } + + // Copy std streams out to the proper place + go copyStream("stdout", stdstream[0], s.Stdout) + go copyStream("stderr", stdstream[1], s.Stderr) + + // Create the broker and start it up + broker := newMuxBroker(mux) + go broker.Run() + + // Use the control connection to build the dispenser and serve the + // connection. + server := rpc.NewServer() + server.RegisterName("Control", &controlServer{ + server: s, + }) + server.RegisterName("Dispenser", &dispenseServer{ + broker: broker, + plugins: s.Plugins, + }) + server.ServeConn(control) +} + +// done is called internally by the control server to trigger the +// doneCh to close which is listened to by the main process to cleanly +// exit. +func (s *RPCServer) done() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.DoneCh != nil { + close(s.DoneCh) + s.DoneCh = nil + } +} + +// dispenseServer dispenses variousinterface implementations for Terraform. +type controlServer struct { + server *RPCServer +} + +func (c *controlServer) Quit( + null bool, response *struct{}) error { + // End the server + c.server.done() + + // Always return true + *response = struct{}{} + + return nil +} + +// dispenseServer dispenses variousinterface implementations for Terraform. +type dispenseServer struct { + broker *MuxBroker + plugins map[string]Plugin +} + +func (d *dispenseServer) Dispense( + name string, response *uint32) error { + // Find the function to create this implementation + p, ok := d.plugins[name] + if !ok { + return fmt.Errorf("unknown plugin type: %s", name) + } + + // Create the implementation first so we know if there is an error. + impl, err := p.Server(d.broker) + if err != nil { + // We turn the error into an errors error so that it works across RPC + return errors.New(err.Error()) + } + + // Reserve an ID for our implementation + id := d.broker.NextId() + *response = id + + // Run the rest in a goroutine since it can only happen once this RPC + // call returns. We wait for a connection for the plugin implementation + // and serve it. + go func() { + conn, err := d.broker.Accept(id) + if err != nil { + log.Printf("[ERR] go-plugin: plugin dispense error: %s: %s", name, err) + return + } + + serve(conn, "Plugin", impl) + }() + + return nil +} + +func serve(conn io.ReadWriteCloser, name string, v interface{}) { + server := rpc.NewServer() + if err := server.RegisterName(name, v); err != nil { + log.Printf("[ERR] go-plugin: plugin dispense error: %s", err) + return + } + + server.ServeConn(conn) +} diff --git a/vendor/github.com/hashicorp/go-plugin/server.go b/vendor/github.com/hashicorp/go-plugin/server.go new file mode 100644 index 000000000..782a4e119 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/server.go @@ -0,0 +1,235 @@ +package plugin + +import ( + "crypto/tls" + "errors" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/signal" + "runtime" + "strconv" + "sync/atomic" +) + +// CoreProtocolVersion is the ProtocolVersion of the plugin system itself. +// We will increment this whenever we change any protocol behavior. This +// will invalidate any prior plugins but will at least allow us to iterate +// on the core in a safe way. We will do our best to do this very +// infrequently. +const CoreProtocolVersion = 1 + +// HandshakeConfig is the configuration used by client and servers to +// handshake before starting a plugin connection. This is embedded by +// both ServeConfig and ClientConfig. +// +// In practice, the plugin host creates a HandshakeConfig that is exported +// and plugins then can easily consume it. +type HandshakeConfig struct { + // ProtocolVersion is the version that clients must match on to + // agree they can communicate. This should match the ProtocolVersion + // set on ClientConfig when using a plugin. + ProtocolVersion uint + + // MagicCookieKey and value are used as a very basic verification + // that a plugin is intended to be launched. This is not a security + // measure, just a UX feature. If the magic cookie doesn't match, + // we show human-friendly output. + MagicCookieKey string + MagicCookieValue string +} + +// ServeConfig configures what sorts of plugins are served. +type ServeConfig struct { + // HandshakeConfig is the configuration that must match clients. + HandshakeConfig + + // Plugins are the plugins that are served. + Plugins map[string]Plugin + + // TLSProvider is a function that returns a configured tls.Config. + TLSProvider func() (*tls.Config, error) +} + +// Serve serves the plugins given by ServeConfig. +// +// Serve doesn't return until the plugin is done being executed. Any +// errors will be outputted to the log. +// +// This is the method that plugins should call in their main() functions. +func Serve(opts *ServeConfig) { + // Validate the handshake config + if opts.MagicCookieKey == "" || opts.MagicCookieValue == "" { + fmt.Fprintf(os.Stderr, + "Misconfigured ServeConfig given to serve this plugin: no magic cookie\n"+ + "key or value was set. Please notify the plugin author and report\n"+ + "this as a bug.\n") + os.Exit(1) + } + + // First check the cookie + if os.Getenv(opts.MagicCookieKey) != opts.MagicCookieValue { + fmt.Fprintf(os.Stderr, + "This binary is a plugin. These are not meant to be executed directly.\n"+ + "Please execute the program that consumes these plugins, which will\n"+ + "load any plugins automatically\n") + os.Exit(1) + } + + // Logging goes to the original stderr + log.SetOutput(os.Stderr) + + // Create our new stdout, stderr files. These will override our built-in + // stdout/stderr so that it works across the stream boundary. + stdout_r, stdout_w, err := os.Pipe() + if err != nil { + fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err) + os.Exit(1) + } + stderr_r, stderr_w, err := os.Pipe() + if err != nil { + fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err) + os.Exit(1) + } + + // Register a listener so we can accept a connection + listener, err := serverListener() + if err != nil { + log.Printf("[ERR] plugin: plugin init: %s", err) + return + } + + if opts.TLSProvider != nil { + tlsConfig, err := opts.TLSProvider() + if err != nil { + log.Printf("[ERR] plugin: plugin tls init: %s", err) + return + } + listener = tls.NewListener(listener, tlsConfig) + } + defer listener.Close() + + // Create the channel to tell us when we're done + doneCh := make(chan struct{}) + + // Create the RPC server to dispense + server := &RPCServer{ + Plugins: opts.Plugins, + Stdout: stdout_r, + Stderr: stderr_r, + DoneCh: doneCh, + } + + // Output the address and service name to stdout so that core can bring it up. + log.Printf("[DEBUG] plugin: plugin address: %s %s\n", + listener.Addr().Network(), listener.Addr().String()) + fmt.Printf("%d|%d|%s|%s\n", + CoreProtocolVersion, + opts.ProtocolVersion, + listener.Addr().Network(), + listener.Addr().String()) + os.Stdout.Sync() + + // Eat the interrupts + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + go func() { + var count int32 = 0 + for { + <-ch + newCount := atomic.AddInt32(&count, 1) + log.Printf( + "[DEBUG] plugin: received interrupt signal (count: %d). Ignoring.", + newCount) + } + }() + + // Set our new out, err + os.Stdout = stdout_w + os.Stderr = stderr_w + + // Serve + go server.Accept(listener) + + // Wait for the graceful exit + <-doneCh +} + +func serverListener() (net.Listener, error) { + if runtime.GOOS == "windows" { + return serverListener_tcp() + } + + return serverListener_unix() +} + +func serverListener_tcp() (net.Listener, error) { + minPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MIN_PORT"), 10, 32) + if err != nil { + return nil, err + } + + maxPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MAX_PORT"), 10, 32) + if err != nil { + return nil, err + } + + for port := minPort; port <= maxPort; port++ { + address := fmt.Sprintf("127.0.0.1:%d", port) + listener, err := net.Listen("tcp", address) + if err == nil { + return listener, nil + } + } + + return nil, errors.New("Couldn't bind plugin TCP listener") +} + +func serverListener_unix() (net.Listener, error) { + tf, err := ioutil.TempFile("", "plugin") + if err != nil { + return nil, err + } + path := tf.Name() + + // Close the file and remove it because it has to not exist for + // the domain socket. + if err := tf.Close(); err != nil { + return nil, err + } + if err := os.Remove(path); err != nil { + return nil, err + } + + l, err := net.Listen("unix", path) + if err != nil { + return nil, err + } + + // Wrap the listener in rmListener so that the Unix domain socket file + // is removed on close. + return &rmListener{ + Listener: l, + Path: path, + }, nil +} + +// rmListener is an implementation of net.Listener that forwards most +// calls to the listener but also removes a file as part of the close. We +// use this to cleanup the unix domain socket on close. +type rmListener struct { + net.Listener + Path string +} + +func (l *rmListener) Close() error { + // Close the listener itself + if err := l.Listener.Close(); err != nil { + return err + } + + // Remove the file + return os.Remove(l.Path) +} diff --git a/vendor/github.com/hashicorp/go-plugin/server_mux.go b/vendor/github.com/hashicorp/go-plugin/server_mux.go new file mode 100644 index 000000000..033079ea0 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/server_mux.go @@ -0,0 +1,31 @@ +package plugin + +import ( + "fmt" + "os" +) + +// ServeMuxMap is the type that is used to configure ServeMux +type ServeMuxMap map[string]*ServeConfig + +// ServeMux is like Serve, but serves multiple types of plugins determined +// by the argument given on the command-line. +// +// This command doesn't return until the plugin is done being executed. Any +// errors are logged or output to stderr. +func ServeMux(m ServeMuxMap) { + if len(os.Args) != 2 { + fmt.Fprintf(os.Stderr, + "Invoked improperly. This is an internal command that shouldn't\n"+ + "be manually invoked.\n") + os.Exit(1) + } + + opts, ok := m[os.Args[1]] + if !ok { + fmt.Fprintf(os.Stderr, "Unknown plugin: %s\n", os.Args[1]) + os.Exit(1) + } + + Serve(opts) +} diff --git a/vendor/github.com/hashicorp/go-plugin/stream.go b/vendor/github.com/hashicorp/go-plugin/stream.go new file mode 100644 index 000000000..1d547aaaa --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/stream.go @@ -0,0 +1,18 @@ +package plugin + +import ( + "io" + "log" +) + +func copyStream(name string, dst io.Writer, src io.Reader) { + if src == nil { + panic(name + ": src is nil") + } + if dst == nil { + panic(name + ": dst is nil") + } + if _, err := io.Copy(dst, src); err != nil && err != io.EOF { + log.Printf("[ERR] plugin: stream copy '%s' error: %s", name, err) + } +} diff --git a/vendor/github.com/hashicorp/go-plugin/testing.go b/vendor/github.com/hashicorp/go-plugin/testing.go new file mode 100644 index 000000000..9086a1b45 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/testing.go @@ -0,0 +1,76 @@ +package plugin + +import ( + "bytes" + "net" + "net/rpc" + "testing" +) + +// The testing file contains test helpers that you can use outside of +// this package for making it easier to test plugins themselves. + +// TestConn is a helper function for returning a client and server +// net.Conn connected to each other. +func TestConn(t *testing.T) (net.Conn, net.Conn) { + // Listen to any local port. This listener will be closed + // after a single connection is established. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Start a goroutine to accept our client connection + var serverConn net.Conn + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + defer l.Close() + var err error + serverConn, err = l.Accept() + if err != nil { + t.Fatalf("err: %s", err) + } + }() + + // Connect to the server + clientConn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Wait for the server side to acknowledge it has connected + <-doneCh + + return clientConn, serverConn +} + +// TestRPCConn returns a rpc client and server connected to each other. +func TestRPCConn(t *testing.T) (*rpc.Client, *rpc.Server) { + clientConn, serverConn := TestConn(t) + + server := rpc.NewServer() + go server.ServeConn(serverConn) + + client := rpc.NewClient(clientConn) + return client, server +} + +// TestPluginRPCConn returns a plugin RPC client and server that are connected +// together and configured. +func TestPluginRPCConn(t *testing.T, ps map[string]Plugin) (*RPCClient, *RPCServer) { + // Create two net.Conns we can use to shuttle our control connection + clientConn, serverConn := TestConn(t) + + // Start up the server + server := &RPCServer{Plugins: ps, Stdout: new(bytes.Buffer), Stderr: new(bytes.Buffer)} + go server.ServeConn(serverConn) + + // Connect the client to the server + client, err := NewRPCClient(clientConn, ps) + if err != nil { + t.Fatalf("err: %s", err) + } + + return client, server +} diff --git a/vendor/vendor.json b/vendor/vendor.json index ee93f5e23..ffcc7c5d4 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -804,6 +804,12 @@ "revision": "ed905158d87462226a13fe39ddf685ea65f1c11f", "revisionTime": "2016-12-16T18:43:04Z" }, + { + "checksumSHA1": "FOLPOFo4xuUaErsL99EC8azEUjw=", + "path": "github.com/hashicorp/go-plugin", + "revision": "b6691c5cfe7f0ec984114b056889cc90e51e38d0", + "revisionTime": "2017-04-12T21:16:38Z" + }, { "checksumSHA1": "ErJHGU6AVPZM9yoY/xV11TwSjQs=", "path": "github.com/hashicorp/go-retryablehttp", From a9a05f5bba5ddb9afd588b174ecd63b37f3f71f9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 16:41:06 -0700 Subject: [PATCH 073/162] Update Type() to return an error --- builtin/logical/database/backend.go | 2 +- builtin/logical/database/dbplugin/client.go | 10 +++++----- .../logical/database/dbplugin/databasemiddleware.go | 4 ++-- builtin/logical/database/dbplugin/plugin.go | 12 +++++++++--- builtin/logical/database/dbplugin/plugin_test.go | 12 ++++++------ builtin/logical/database/dbplugin/server.go | 5 +++-- 6 files changed, 26 insertions(+), 19 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 618ffac6f..c8f9ad854 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -162,5 +162,5 @@ as secret backends, including but not limited to: cassandra, msslq, mysql, postgres After mounting this backend, configure it using the endpoints within -the "database/dbs/" path. +the "database/config/" path. ` diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 5bdc3a01a..93db86595 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -52,10 +52,11 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn return nil, err } - // We should have a Greeter now! This feels like a normal interface + // We should have a database type now. This feels like a normal interface // implementation but is in fact over an RPC connection. databaseRPC := raw.(*databasePluginRPCClient) + // Wrap RPC implimentation in DatabasePluginClient return &DatabasePluginClient{ client: client, databasePluginRPCClient: databaseRPC, @@ -70,12 +71,11 @@ type databasePluginRPCClient struct { client *rpc.Client } -func (dr *databasePluginRPCClient) Type() string { +func (dr *databasePluginRPCClient) Type() (string, error) { var dbType string - //TODO: catch error - dr.client.Call("Plugin.Type", struct{}{}, &dbType) + err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) - return fmt.Sprintf("plugin-%s", dbType) + return fmt.Sprintf("plugin-%s", dbType), err } func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 2137cd9c3..e28a8741e 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -18,7 +18,7 @@ type databaseTracingMiddleware struct { typeStr string } -func (mw *databaseTracingMiddleware) Type() string { +func (mw *databaseTracingMiddleware) Type() (string, error) { return mw.next.Type() } @@ -87,7 +87,7 @@ type databaseMetricsMiddleware struct { typeStr string } -func (mw *databaseMetricsMiddleware) Type() string { +func (mw *databaseMetricsMiddleware) Type() (string, error) { return mw.next.Type() } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index dadb6639e..5e6ce939b 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -2,6 +2,7 @@ package dbplugin import ( "errors" + "fmt" "net/rpc" "time" @@ -16,7 +17,7 @@ var ( // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { - Type() string + Type() (string, error) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error @@ -52,16 +53,21 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log return nil, err } + typeStr, err := db.Type() + if err != nil { + return nil, fmt.Errorf("error getting plugin type: %s", err) + } + // Wrap with metrics middleware db = &databaseMetricsMiddleware{ next: db, - typeStr: db.Type(), + typeStr: typeStr, } // Wrap with tracing middleware db = &databaseTracingMiddleware{ next: db, - typeStr: db.Type(), + typeStr: typeStr, logger: logger, } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 7909bbd4e..1587ba24a 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -19,7 +19,7 @@ type mockPlugin struct { users map[string][]string } -func (m *mockPlugin) Type() string { return "mock" } +func (m *mockPlugin) Type() (string, error) { return "mock", nil } func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { err = errors.New("err") if usernamePrefix == "" || expiration.IsZero() { @@ -59,7 +59,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) delete(m.users, username) return nil } -func (m *mockPlugin) Initialize(conf map[string]interface{}) error { +func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { err := errors.New("err") if len(conf) != 1 { return err @@ -108,7 +108,7 @@ func TestPlugin_Initialize(t *testing.T) { "test": 1, } - err = dbRaw.Initialize(connectionDetails) + err = dbRaw.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -133,7 +133,7 @@ func TestPlugin_CreateUser(t *testing.T) { "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -167,7 +167,7 @@ func TestPlugin_RenewUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -196,7 +196,7 @@ func TestPlugin_RevokeUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 326e25103..3a3e23394 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -42,8 +42,9 @@ type databasePluginRPCServer struct { } func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - *resp = ds.impl.Type() - return nil + var err error + *resp, err = ds.impl.Type() + return err } func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { From 0cfe1ea81c47ecdadabe176002c0ec3e65c0bf90 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 17:35:02 -0700 Subject: [PATCH 074/162] Cleanup path files --- builtin/logical/database/backend.go | 10 +-- builtin/logical/database/dbplugin/plugin.go | 9 --- .../database/path_config_connection.go | 71 +++++++++++-------- builtin/logical/database/path_roles.go | 1 + command/{plugin-exec.go => plugin_exec.go} | 0 5 files changed, 46 insertions(+), 45 deletions(-) rename command/{plugin-exec.go => plugin_exec.go} (100%) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index c8f9ad854..2ce759526 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -12,7 +12,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -const databaseConfigPath = "database/dbs/" +const databaseConfigPath = "database/config/" // DatabaseConfig is used by the Factory function to configure a DatabaseType // object. @@ -32,12 +32,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), - PathsSpecial: &logical.Paths{ - Root: []string{ - "dbs/plugin/*", - }, - }, - Paths: []*framework.Path{ pathConfigurePluginConnection(&b), pathListRoles(&b), @@ -90,7 +84,7 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } - entry, err := s.Get(fmt.Sprintf("dbs/%s", name)) + entry, err := s.Get(fmt.Sprintf("config/%s", name)) if err != nil { return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 5e6ce939b..61de0fe8c 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -1,7 +1,6 @@ package dbplugin import ( - "errors" "fmt" "net/rpc" "time" @@ -11,10 +10,6 @@ import ( log "github.com/mgutz/logxi/v1" ) -var ( - ErrEmptyPluginName = errors.New("empty plugin name") -) - // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { Type() (string, error) @@ -37,10 +32,6 @@ type Statements struct { // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { - if pluginName == "" { - return nil, ErrEmptyPluginName - } - // Look for plugin in the plugin catalog pluginMeta, err := sys.LookupPlugin(pluginName) if err != nil { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 5817f53c2..f69c7761b 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,6 +1,7 @@ package database import ( + "errors" "fmt" "github.com/fatih/structs" @@ -9,6 +10,11 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +var ( + respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") + respErrEmptyName = logical.ErrorResponse("Empty name attribute given") +) + // pathResetConnection configures a path to reset a plugin. func pathResetConnection(b *databaseBackend) *framework.Path { return &framework.Path{ @@ -16,7 +22,7 @@ func pathResetConnection(b *databaseBackend) *framework.Path { Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, - Description: "Name of this DB type", + Description: "Name of this database connection", }, }, @@ -35,15 +41,17 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil + return respErrEmptyName, nil } // Grab the mutex lock b.Lock() defer b.Unlock() + // Close plugin and delete the entry in the connections cache. b.clearConnection(name) + // Execute plugin again, we don't need the object so throw away. _, err := b.getOrCreateDBObj(req.Storage, name) if err != nil { return nil, err @@ -61,14 +69,7 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, - Description: "Name of this DB type", - }, - - "verify_connection": &framework.FieldSchema{ - Type: framework.TypeBool, - Default: true, - Description: `If set, the connection details are verified by - actually connecting to the database`, + Description: "Name of this database connection", }, "plugin_name": &framework.FieldSchema{ @@ -77,6 +78,13 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { plugin known to vault. This endpoint will create an instance of that plugin type.`, }, + + "verify_connection": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: true, + Description: `If true, the connection details are verified by + actually connecting to the database. Defaults to true.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -94,10 +102,13 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) + if name == "" { + return respErrEmptyName, nil + } - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + entry, err := req.Storage.Get(fmt.Sprintf("config/%s", name)) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") + return nil, errors.New("failed to read connection configuration") } if entry == nil { return nil, nil @@ -118,12 +129,12 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil + return respErrEmptyName, nil } - err := req.Storage.Delete(fmt.Sprintf("dbs/%s", name)) + err := req.Storage.Delete(fmt.Sprintf("config/%s", name)) if err != nil { - return nil, fmt.Errorf("failed to delete connection configuration") + return nil, errors.New("failed to delete connection configuration") } b.Lock() @@ -134,9 +145,9 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { if err != nil { return nil, err } - } - delete(b.connections, name) + delete(b.connections, name) + } return nil, nil } @@ -146,22 +157,22 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { // both builtin and plugin database types. func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - - config := &DatabaseConfig{ - ConnectionDetails: data.Raw, - PluginName: data.Get("plugin_name").(string), + pluginName := data.Get("plugin_name").(string) + if pluginName == "" { + return respErrEmptyPluginName, nil } name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil + return respErrEmptyName, nil } verifyConnection := data.Get("verify_connection").(bool) - // Grab the mutex lock - b.Lock() - defer b.Unlock() + config := &DatabaseConfig{ + ConnectionDetails: data.Raw, + PluginName: pluginName, + } db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { @@ -174,6 +185,10 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + // Grab the mutex lock + b.Lock() + defer b.Unlock() + if _, ok := b.connections[name]; ok { // Close and remove the old connection err := b.connections[name].Close() @@ -189,7 +204,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.connections[name] = db // Store it - entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) + entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) if err != nil { return nil, err } @@ -198,7 +213,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } resp := &logical.Response{} - resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any.") return resp, nil } @@ -221,7 +236,7 @@ accepts: plugin known to vault. This endpoint will create an instance of that plugin type. - * "verify_connection" - A boolean value denoting if the plugin should verify + * "verify_connection" (default: true) - A boolean value denoting if the plugin should verify it is able to connect to the database using the provided connection details. ` diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 263a555e6..b3393b1ba 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -109,6 +109,7 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc { return &logical.Response{ Data: map[string]interface{}{ + "db_name": role.DBName, "creation_statements": role.Statements.CreationStatements, "revocation_statements": role.Statements.RevocationStatements, "rollback_statements": role.Statements.RollbackStatements, diff --git a/command/plugin-exec.go b/command/plugin_exec.go similarity index 100% rename from command/plugin-exec.go rename to command/plugin_exec.go From cfe25e2a00dc27f2ba935434e0431d56e52164a0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 17:35:53 -0700 Subject: [PATCH 075/162] Add comments to the plugin runner --- helper/pluginutil/runner.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 4d66d8706..a57abad0e 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -17,20 +17,28 @@ var ( PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" ) +// Looker defines the plugin Lookup function that looks into the plugin catalog +// for availible plugins and returns a PluginRunner type Looker interface { LookupPlugin(string) (*PluginRunner, error) } +// Wrapper interface defines the functions needed by the runner to wrap the +// metadata needed to run a plugin process. This includes looking up Mlock +// configuration and wrapping data in a respose wrapped token. type Wrapper interface { ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) MlockDisabled() bool } +// LookWrapper defines the functions for both Looker and Wrapper type LookWrapper interface { Looker Wrapper } +// PluginRunner defines the metadata needed to run a plugin securely with +// go-plugin. type PluginRunner struct { Name string `json:"name"` Command string `json:"command"` @@ -39,6 +47,8 @@ type PluginRunner struct { Builtin bool `json:"builtin"` } +// Run takes a wrapper instance, and the go-plugin paramaters and executes a +// plugin. func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate CACertBytes, CACert, CAKey, err := GenerateCACert() @@ -87,6 +97,8 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return client, nil } +// OptionallyEnableMlock determines if mlock should be called, and if so enables +// mlock. func OptionallyEnableMlock() error { if os.Getenv(PluginMlockEnabled) == "true" { return mlock.LockMemory() From 883c80540a051ed3d22888253b9fe965ad419267 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 10:33:34 -0700 Subject: [PATCH 076/162] Add allowed_roles parameter and checks --- builtin/logical/database/backend.go | 36 +++---- builtin/logical/database/backend_test.go | 101 ++++++++++++++++++ .../database/path_config_connection.go | 31 +++++- builtin/logical/database/path_role_create.go | 12 +++ 4 files changed, 158 insertions(+), 22 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 2ce759526..e57fa19c1 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -14,15 +14,6 @@ import ( const databaseConfigPath = "database/config/" -// DatabaseConfig is used by the Factory function to configure a DatabaseType -// object. -type DatabaseConfig struct { - PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` - // ConnectionDetails stores the database specific connection settings needed - // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` -} - func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -84,16 +75,8 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } - entry, err := s.Get(fmt.Sprintf("config/%s", name)) + config, err := b.DatabaseConfig(s, name) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) - } - if entry == nil { - return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) - } - - var config DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -112,6 +95,23 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } +func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) { + entry, err := s.Get(fmt.Sprintf("config/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + } + if entry == nil { + return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) + } + + var config DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + return &config, nil +} + func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 5b3a0db42..2615577fd 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -130,6 +130,7 @@ func TestBackend_config_connection(t *testing.T) { expected := map[string]interface{}{ "plugin_name": "postgresql-database-plugin", "connection_details": configData, + "allowed_roles": []string{}, } configReq.Operation = logical.ReadOperation resp, err = b.HandleRequest(configReq) @@ -306,6 +307,7 @@ func TestBackend_connectionCrud(t *testing.T) { expected := map[string]interface{}{ "plugin_name": "postgresql-database-plugin", "connection_details": data, + "allowed_roles": []string{}, } req.Operation = logical.ReadOperation resp, err = b.HandleRequest(req) @@ -484,6 +486,105 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatal("Expected response to be nil") } } +func TestBackend_allowedRoles(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "allow, allowed", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a denied and an allowed role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/denied", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/allowed", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds from denied role, should fail + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/denied", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != logical.ErrPermissionDenied { + t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) + } + + // Get creds from allowed role, should work. + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/allowed", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } +} func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool { var d struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f69c7761b..2a0022b4d 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -6,15 +6,26 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) var ( - respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") + respErrEmptyPluginName = logical.ErrorResponse("Empty plugin name") respErrEmptyName = logical.ErrorResponse("Empty name attribute given") ) +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. +type DatabaseConfig struct { + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"` +} + // pathResetConnection configures a path to reset a plugin. func pathResetConnection(b *databaseBackend) *framework.Path { return &framework.Path{ @@ -75,15 +86,22 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `The name of a builtin or previously registered - plugin known to vault. This endpoint will create an instance of - that plugin type.`, + plugin known to vault. This endpoint will create an instance of + that plugin type.`, }, "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, Description: `If true, the connection details are verified by - actually connecting to the database. Defaults to true.`, + actually connecting to the database. Defaults to true.`, + }, + + "allowed_roles": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Comma separated list of the role names allowed to + get creds from this database connection. If not set all roles + are allowed.`, }, }, @@ -169,9 +187,14 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { verifyConnection := data.Get("verify_connection").(bool) + // Pasrse and dedupe allowed roles from a comma separated string. + allowedRolesRaw := data.Get("allowed_roles").(string) + allowedRoles := strutil.ParseDedupAndSortStrings(allowedRolesRaw, ",") + config := &DatabaseConfig{ ConnectionDetails: data.Raw, PluginName: pluginName, + AllowedRoles: allowedRoles, } db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 59584e943..631802dff 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -40,6 +41,17 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil } + dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName) + if err != nil { + return nil, err + } + + // If role name isn't in the database's allowed roles, send back a + // permission denied. + if len(dbConfig.AllowedRoles) > 0 && !strutil.StrListContains(dbConfig.AllowedRoles, name) { + return nil, logical.ErrPermissionDenied + } + b.Lock() defer b.Unlock() From 8a3ef906d596e355b275259a3d33c8ceb4f9255d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 11:22:53 -0700 Subject: [PATCH 077/162] Update the plugin directory logic --- command/server.go | 7 +++++++ command/server/config.go | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/command/server.go b/command/server.go index 17f40d165..c402d09d5 100644 --- a/command/server.go +++ b/command/server.go @@ -284,6 +284,13 @@ func (c *ServerCommand) Run(args []string) int { return 1 } coreConfig.PluginDirectory = filepath.Join(homePath, "/.vault-plugins/") + err = os.Mkdir(coreConfig.PluginDirectory, 0700) + if err != nil && !os.IsExist(err) { + c.Ui.Output(fmt.Sprintf( + "Error making default plugin directory: %v", err)) + return 1 + } + } var disableClustering bool diff --git a/command/server/config.go b/command/server/config.go index 4821a29ba..dad485928 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -273,6 +273,11 @@ func (c *Config) Merge(c2 *Config) *Config { result.EnableUI = c2.EnableUI } + result.PluginDirectory = c.PluginDirectory + if c2.PluginDirectory != "" { + result.PluginDirectory = c2.PluginDirectory + } + return result } From 4e9f89430c7fb2dbe5131c4a6afa9b9b100216c5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 13:48:32 -0700 Subject: [PATCH 078/162] Move plugins into main vault repo --- helper/builtinplugins/builtin.go | 4 +- .../mssql/mssql-database-plugin/main.go | 16 + plugins/database/mssql/mssql.go | 268 ++++++++++++++ plugins/database/mssql/mssql_test.go | 173 +++++++++ .../mysql/mysql-database-plugin/main.go | 16 + plugins/database/mysql/mysql.go | 183 ++++++++++ plugins/database/mysql/mysql_test.go | 200 +++++++++++ .../postgresql-database-plugin/main.go | 16 + plugins/database/postgresql/postgresql.go | 337 ++++++++++++++++++ .../database/postgresql/postgresql_test.go | 308 ++++++++++++++++ plugins/helper/database/connutil/cassandra.go | 172 +++++++++ plugins/helper/database/connutil/connutil.go | 21 ++ plugins/helper/database/connutil/sql.go | 131 +++++++ .../helper/database/credsutil/cassandra.go | 37 ++ .../helper/database/credsutil/credsutil.go | 12 + plugins/helper/database/credsutil/sql.go | 43 +++ plugins/helper/database/dbutil/dbutil.go | 20 ++ 17 files changed, 1955 insertions(+), 2 deletions(-) create mode 100644 plugins/database/mssql/mssql-database-plugin/main.go create mode 100644 plugins/database/mssql/mssql.go create mode 100644 plugins/database/mssql/mssql_test.go create mode 100644 plugins/database/mysql/mysql-database-plugin/main.go create mode 100644 plugins/database/mysql/mysql.go create mode 100644 plugins/database/mysql/mysql_test.go create mode 100644 plugins/database/postgresql/postgresql-database-plugin/main.go create mode 100644 plugins/database/postgresql/postgresql.go create mode 100644 plugins/database/postgresql/postgresql_test.go create mode 100644 plugins/helper/database/connutil/cassandra.go create mode 100644 plugins/helper/database/connutil/connutil.go create mode 100644 plugins/helper/database/connutil/sql.go create mode 100644 plugins/helper/database/credsutil/cassandra.go create mode 100644 plugins/helper/database/credsutil/credsutil.go create mode 100644 plugins/helper/database/credsutil/sql.go create mode 100644 plugins/helper/database/dbutil/dbutil.go diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 55da9a97f..beedbb15b 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,8 +1,8 @@ package builtinplugins import ( - "github.com/hashicorp/vault-plugins/database/mysql" - "github.com/hashicorp/vault-plugins/database/postgresql" + "github.com/hashicorp/vault/plugins/database/mysql" + "github.com/hashicorp/vault/plugins/database/postgresql" ) var BuiltinPlugins *builtinPlugins = &builtinPlugins{ diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go new file mode 100644 index 000000000..ead1cf842 --- /dev/null +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/mssql" +) + +func main() { + err := mssql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go new file mode 100644 index 000000000..567a095b6 --- /dev/null +++ b/plugins/database/mssql/mssql.go @@ -0,0 +1,268 @@ +package mssql + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const msSQLTypeName = "mssql" + +// MSSQL is an implementation of DatabaseType interface +type MSSQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *MSSQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = msSQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &MSSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MSSQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +// Type returns the TypeName for this backend +func (m *MSSQL) Type() (string, error) { + return msSQLTypeName, nil +} + +func (m *MSSQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by +// the CreationStatement provided. +func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return "", "", err + } + + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + username, err = m.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = m.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := m.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + } + + return username, password, nil +} + +// RenewUser is not supported on MSSQL, so this is a no-op. +func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // NOOP + return nil +} + +// RevokeUser attempts to drop the specified user. It will first attempt to disable login, +// then kill pending connections from that user, and finally drop the user and login from the +// database instance. +func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // First disable server login + disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) + if err != nil { + return err + } + defer disableStmt.Close() + if _, err := disableStmt.Exec(); err != nil { + return err + } + + // Query for sessions for the login so that we can kill any outstanding + // sessions. There cannot be any active sessions before we drop the logins + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + sessionStmt, err := db.Prepare(fmt.Sprintf( + "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) + if err != nil { + return err + } + defer sessionStmt.Close() + + sessionRows, err := sessionStmt.Query() + if err != nil { + return err + } + defer sessionRows.Close() + + var revokeStmts []string + for sessionRows.Next() { + var sessionID int + err = sessionRows.Scan(&sessionID) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) + } + + // Query for database users using undocumented stored procedure for now since + // it is the easiest way to get this information; + // we need to drop the database users before we can drop the login and the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var loginName, dbName, qUsername string + var aliasName sql.NullString + err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) + } + + // we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revokeStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all database users are dropped + if rows.Err() != nil { + return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) + } + + // Drop this login + stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +const dropUserSQL = ` +USE [%s] +IF EXISTS + (SELECT name + FROM sys.database_principals + WHERE name = N'%s') +BEGIN + DROP USER [%s] +END +` + +const dropLoginSQL = ` +IF EXISTS + (SELECT name + FROM master.sys.server_principals + WHERE name = N'%s') +BEGIN + DROP LOGIN [%s] +END +` diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go new file mode 100644 index 000000000..bc182f26f --- /dev/null +++ b/plugins/database/mssql/mssql_test.go @@ -0,0 +1,173 @@ +package mssql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMSQLImagePull sync.Once +) + +func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MSSQL_URL") != "" { + return func() {}, os.Getenv("MSSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) + if err != nil { + t.Fatalf("Could not start local MSSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mssql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MSSQL docker container: %s", err) + } + + return +} + +func TestMSSQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMSSQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testMSSQLRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMSSQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testMSSQLRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "sa:yourStrong(!)Password", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("mssql", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testMSSQLRole = ` +CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; +CREATE USER [{{name}}] FOR LOGIN [{{name}}]; +GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go new file mode 100644 index 000000000..c0ec75c9c --- /dev/null +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/mysql" +) + +func main() { + err := mysql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go new file mode 100644 index 000000000..ea14a6782 --- /dev/null +++ b/plugins/database/mysql/mysql.go @@ -0,0 +1,183 @@ +package mysql + +import ( + "database/sql" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const defaultMysqlRevocationStmts = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' +` +const mySQLTypeName = "mysql" + +type MySQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *MySQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = mySQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &MySQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MySQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +func (m *MySQL) Type() (string, error) { + return mySQLTypeName, nil +} + +func (m *MySQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (m *MySQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return "", "", err + } + + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + username, err = m.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = m.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := m.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + } + + return username, password, nil +} + +// NOOP +func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + return nil +} + +func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the read lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return err + } + + revocationStmts := statements.RevocationStatements + // Use a default SQL statement for revocation if one cannot be fetched from the role + if revocationStmts == "" { + revocationStmts = defaultMysqlRevocationStmts + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // This is not a prepared statement because not all commands are supported + // 1295: This command is not supported in the prepared statement protocol yet + // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ + query = strings.Replace(query, "{{name}}", username, -1) + _, err = tx.Exec(query) + if err != nil { + return err + } + + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go new file mode 100644 index 000000000..2b1f27291 --- /dev/null +++ b/plugins/database/mysql/mysql_test.go @@ -0,0 +1,200 @@ +package mysql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMySQLImagePull sync.Once +) + +func prepareMySQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MYSQL_URL") != "" { + return func() {}, os.Getenv("MYSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("mysql", "latest", []string{"MYSQL_ROOT_PASSWORD=secret"}) + if err != nil { + t.Fatalf("Could not start local MySQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("root:secret@(localhost:%s)/mysql?parseTime=true", resource.GetPort("3306/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mysql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MySQL docker container: %s", err) + } + + return +} + +func TestMySQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMySQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMySQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + + statements.CreationStatements = testMySQLRoleWildCard + username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = testMySQLRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "root:secret", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("mysql", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testMySQLRoleWildCard = ` +CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'%'; +` +const testMySQLRevocationSQL = ` +REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; +DROP USER '{{name}}'@'%'; +` diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go new file mode 100644 index 000000000..9b9b813c4 --- /dev/null +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/postgresql" +) + +func main() { + err := postgresql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go new file mode 100644 index 000000000..b8449f549 --- /dev/null +++ b/plugins/database/postgresql/postgresql.go @@ -0,0 +1,337 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" + "github.com/lib/pq" +) + +const postgreSQLTypeName string = "postgres" + +func New() *PostgreSQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = postgreSQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &PostgreSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instatiates a PostgreSQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +type PostgreSQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func (p *PostgreSQL) Type() (string, error) { + return postgreSQLTypeName, nil +} + +func (p *PostgreSQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + // Grab the lock + p.Lock() + defer p.Unlock() + + username, err = p.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = p.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := p.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Get the connection + db, err := p.getConnection() + if err != nil { + return "", "", err + + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + + } + defer func() { + tx.Rollback() + }() + // Return the secret + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + + } + + return username, password, nil +} + +func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // Grab the lock + p.Lock() + defer p.Unlock() + + db, err := p.getConnection() + if err != nil { + return err + } + + expirationStr, err := p.GenerateExpiration(expiration) + if err != nil { + return err + } + + query := fmt.Sprintf( + "ALTER ROLE %s VALID UNTIL '%s';", + pq.QuoteIdentifier(username), + expirationStr) + + stmt, err := db.Prepare(query) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the lock + p.Lock() + defer p.Unlock() + + if statements.RevocationStatements == "" { + return p.defaultRevokeUser(username) + } + + return p.customRevokeUser(username, statements.RevocationStatements) +} + +func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { + db, err := p.getConnection() + if err != nil { + return err + } + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + tx.Rollback() + }() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + + if _, err := stmt.Exec(); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) defaultRevokeUser(username string) error { + db, err := p.getConnection() + if err != nil { + return err + } + + // Check if the role exists + var exists bool + err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + + if exists == false { + return nil + } + + // Query for permissions; we need to revoke permissions before we can drop + // the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query(username) + if err != nil { + return err + } + defer rows.Close() + + const initialNumRevocations = 16 + revocationStmts := make([]string, 0, initialNumRevocations) + for rows.Next() { + var schema string + err = rows.Scan(&schema) + if err != nil { + // keep going; remove as many permissions as possible right now + continue + } + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE USAGE ON SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + } + + // for good measure, revoke all privileges and usage on schema public + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE USAGE ON SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + // get the current database name so we can issue a REVOKE CONNECT for + // this username + var dbname sql.NullString + if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { + return err + } + + if dbname.Valid { + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE CONNECT ON DATABASE %s FROM %s;`, + pq.QuoteIdentifier(dbname.String), + pq.QuoteIdentifier(username))) + } + + // again, here, we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revocationStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all privileges are revoked + if rows.Err() != nil { + return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) + } + + // Drop this user + stmt, err = db.Prepare(fmt.Sprintf( + `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go new file mode 100644 index 000000000..c7ccc8ee8 --- /dev/null +++ b/plugins/database/postgresql/postgresql_test.go @@ -0,0 +1,308 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testPostgresImagePull sync.Once +) + +func preparePostgresTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("PG_URL") != "" { + return func() {}, os.Getenv("PG_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"}) + if err != nil { + t.Fatalf("Could not start local PostgreSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("postgres", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to PostgreSQL docker container: %s", err) + } + + return +} + +func TestPostgreSQL_Initialize(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPostgreSQL_CreateUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + statements.CreationStatements = testPostgresReadOnlyRole + username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestPostgreSQL_RenewUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Sleep longer than the inital expiration time + time.Sleep(2 * time.Second) + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestPostgreSQL_RevokeUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + + username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = defaultPostgresRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("postgres", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testPostgresRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testPostgresReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testPostgresBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testPostgresBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultPostgresRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; +REVOKE USAGE ON SCHEMA public FROM "{{name}}"; + +DROP ROLE IF EXISTS "{{name}}"; +` diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go new file mode 100644 index 000000000..305bc6e3d --- /dev/null +++ b/plugins/helper/database/connutil/cassandra.go @@ -0,0 +1,172 @@ +package connutil + +import ( + "crypto/tls" + "fmt" + "strings" + "sync" + "time" + + "github.com/mitchellh/mapstructure" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" +) + +// CassandraConnectionProducer implements ConnectionProducer and provides an +// interface for cassandra databases to make connections. +type CassandraConnectionProducer struct { + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + + Initialized bool + session *gocql.Session + sync.Mutex +} + +func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.Initialized = true + + if verifyConnection { + if _, err := c.connection(); err != nil { + return fmt.Errorf("error Initalizing Connection: %s", err) + } + } + return nil +} + +func (c *CassandraConnectionProducer) connection() (interface{}, error) { + if !c.Initialized { + return nil, errNotInitialized + } + + // If we already have a DB, return it + if c.session != nil { + return c.session, nil + } + + session, err := c.createSession() + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + c.session = session + + return session, nil +} + +func (c *CassandraConnectionProducer) Close() error { + // Grab the write lock + c.Lock() + defer c.Unlock() + + if c.session != nil { + c.session.Close() + } + + c.session = nil + + return nil +} + +func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: c.Username, + Password: c.Password, + } + + clusterConfig.ProtoVersion = c.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second + + if c.TLS { + var tlsConfig *tls.Config + if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { + if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(c.Certificate) > 0 { + certBundle.Certificate = c.Certificate + certBundle.PrivateKey = c.PrivateKey + } + if len(c.IssuingCA) > 0 { + certBundle.IssuingCA = c.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = c.InsecureTLS + + if c.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("error creating session: %s", err) + } + + // Set consistency + if c.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) + if err != nil { + return nil, err + } + + session.SetConsistency(consistencyValue) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("error validating connection info: %s", err) + } + + return session, nil +} diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go new file mode 100644 index 000000000..6de3299e3 --- /dev/null +++ b/plugins/helper/database/connutil/connutil.go @@ -0,0 +1,21 @@ +package connutil + +import ( + "errors" + "sync" +) + +var ( + errNotInitialized = errors.New("connection has not been initalized") +) + +// ConnectionProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods dealing with individual database +// connections and is used in all the builtin database types. +type ConnectionProducer interface { + Close() error + Initialize(map[string]interface{}, bool) error + Connection() (interface{}, error) + + sync.Locker +} diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go new file mode 100644 index 000000000..0bfc5f9f6 --- /dev/null +++ b/plugins/helper/database/connutil/sql.go @@ -0,0 +1,131 @@ +package connutil + +import ( + "database/sql" + "fmt" + "strings" + "sync" + "time" + + // Import sql drivers + _ "github.com/denisenkom/go-mssqldb" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/mitchellh/mapstructure" +) + +// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases +type SQLConnectionProducer struct { + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + + Type string + MaxConnectionLifetime time.Duration + Initialized bool + db *sql.DB + sync.Mutex +} + +func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + + if c.MaxOpenConnections == 0 { + c.MaxOpenConnections = 2 + } + + if c.MaxIdleConnections == 0 { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxIdleConnections > c.MaxOpenConnections { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxConnectionLifetimeRaw == "" { + c.MaxConnectionLifetimeRaw = "0s" + } + + c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw) + if err != nil { + return fmt.Errorf("invalid max_connection_lifetime: %s", err) + } + + if verifyConnection { + if _, err := c.Connection(); err != nil { + return fmt.Errorf("error initalizing connection: %s", err) + } + + if err := c.db.Ping(); err != nil { + return fmt.Errorf("error initalizing connection: %s", err) + } + } + + c.Initialized = true + + return nil +} + +func (c *SQLConnectionProducer) Connection() (interface{}, error) { + // If we already have a DB, test it and return + if c.db != nil { + if err := c.db.Ping(); err == nil { + return c.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + c.db.Close() + } + + // For mssql backend, switch to sqlserver instead + dbType := c.Type + if c.Type == "mssql" { + dbType = "sqlserver" + } + + // Otherwise, attempt to make connection + conn := c.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } + + var err error + c.db, err = sql.Open(dbType, conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + c.db.SetMaxOpenConns(c.MaxOpenConnections) + c.db.SetMaxIdleConns(c.MaxIdleConnections) + c.db.SetConnMaxLifetime(c.MaxConnectionLifetime) + + return c.db, nil +} + +// Close attempts to close the connection +func (c *SQLConnectionProducer) Close() error { + // Grab the write lock + c.Lock() + defer c.Unlock() + + if c.db != nil { + c.db.Close() + } + + c.db = nil + + return nil +} diff --git a/plugins/helper/database/credsutil/cassandra.go b/plugins/helper/database/credsutil/cassandra.go new file mode 100644 index 000000000..7ab5630b5 --- /dev/null +++ b/plugins/helper/database/credsutil/cassandra.go @@ -0,0 +1,37 @@ +package credsutil + +import ( + "fmt" + "strings" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +// CassandraCredentialsProducer implements CredentialsProducer and provides an +// interface for cassandra databases to generate user information. +type CassandraCredentialsProducer struct{} + +func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) + username = strings.Replace(username, "-", "_", -1) + + return username, nil +} + +func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (ccp *CassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) { + return "", nil +} diff --git a/plugins/helper/database/credsutil/credsutil.go b/plugins/helper/database/credsutil/credsutil.go new file mode 100644 index 000000000..7f388a0f7 --- /dev/null +++ b/plugins/helper/database/credsutil/credsutil.go @@ -0,0 +1,12 @@ +package credsutil + +import "time" + +// CredentialsProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods for generating user information for a +// particular database type and is used in all the builtin database types. +type CredentialsProducer interface { + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Time) (string, error) +} diff --git a/plugins/helper/database/credsutil/sql.go b/plugins/helper/database/credsutil/sql.go new file mode 100644 index 000000000..23e98102f --- /dev/null +++ b/plugins/helper/database/credsutil/sql.go @@ -0,0 +1,43 @@ +package credsutil + +import ( + "fmt" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +// SQLCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. +type SQLCredentialsProducer struct { + DisplayNameLen int + UsernameLen int +} + +func (scp *SQLCredentialsProducer) GenerateUsername(displayName string) (string, error) { + if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen { + displayName = displayName[:scp.DisplayNameLen] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { + username = username[:scp.UsernameLen] + } + + return username, nil +} + +func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) { + return ttl.Format("2006-01-02 15:04:05-0700"), nil +} diff --git a/plugins/helper/database/dbutil/dbutil.go b/plugins/helper/database/dbutil/dbutil.go new file mode 100644 index 000000000..e80273b7f --- /dev/null +++ b/plugins/helper/database/dbutil/dbutil.go @@ -0,0 +1,20 @@ +package dbutil + +import ( + "errors" + "fmt" + "strings" +) + +var ( + ErrEmptyCreationStatement = errors.New("empty creation statements") +) + +// Query templates a query for us. +func QueryHelper(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +} From 189909931be8a0fd7755485920cb8dd5dcada1d5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 14:30:15 -0700 Subject: [PATCH 079/162] Move mssql to be an acceptance test --- plugins/database/mssql/mssql_test.go | 62 +++++++--------------------- 1 file changed, 14 insertions(+), 48 deletions(-) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index bc182f26f..2bca0a7b8 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -11,56 +11,17 @@ import ( "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/plugins/helper/database/connutil" - dockertest "gopkg.in/ory-am/dockertest.v3" ) var ( testMSQLImagePull sync.Once ) -func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { - if os.Getenv("MSSQL_URL") != "" { - return func() {}, os.Getenv("MSSQL_URL") - } - - pool, err := dockertest.NewPool("") - if err != nil { - t.Fatalf("Failed to connect to docker: %s", err) - } - - resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) - if err != nil { - t.Fatalf("Could not start local MSSQL docker container: %s", err) - } - - cleanup = func() { - err := pool.Purge(resource) - if err != nil { - t.Fatalf("Failed to cleanup local container: %s", err) - } - } - - retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) - - // exponential backoff-retry - if err = pool.Retry(func() error { - var err error - var db *sql.DB - db, err = sql.Open("mssql", retURL) - if err != nil { - return err - } - return db.Ping() - }); err != nil { - t.Fatalf("Could not connect to MSSQL docker container: %s", err) - } - - return -} - func TestMSSQL_Initialize(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() + if os.Getenv("MSSQL_URL") == "" { + return + } + connURL := os.Getenv("MSSQL_URL") connectionDetails := map[string]interface{}{ "connection_url": connURL, @@ -85,8 +46,10 @@ func TestMSSQL_Initialize(t *testing.T) { } func TestMSSQL_CreateUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() + if os.Getenv("MSSQL_URL") == "" { + return + } + connURL := os.Getenv("MSSQL_URL") connectionDetails := map[string]interface{}{ "connection_url": connURL, @@ -119,8 +82,10 @@ func TestMSSQL_CreateUser(t *testing.T) { } func TestMSSQL_RevokeUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() + if os.Getenv("MSSQL_URL") == "" { + return + } + connURL := os.Getenv("MSSQL_URL") connectionDetails := map[string]interface{}{ "connection_url": connURL, @@ -158,7 +123,8 @@ func TestMSSQL_RevokeUser(t *testing.T) { func testCredsExist(t testing.TB, connURL, username, password string) error { // Log in with the new creds - connURL = strings.Replace(connURL, "sa:yourStrong(!)Password", fmt.Sprintf("%s:%s", username, password), 1) + parts := strings.Split(connURL, "@") + connURL = fmt.Sprintf("sqlserver://%s:%s@%s", username, password, parts[1]) db, err := sql.Open("mssql", connURL) if err != nil { return err From 3b4768c5fbf39984b82bcbc97a1a9ed2bb86f712 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 14:40:59 -0700 Subject: [PATCH 080/162] Only run mssql acceptance test when running as VAULT_ACC=1 --- plugins/database/mssql/mssql_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 2bca0a7b8..512033bd7 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -18,7 +18,7 @@ var ( ) func TestMSSQL_Initialize(t *testing.T) { - if os.Getenv("MSSQL_URL") == "" { + if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" { return } connURL := os.Getenv("MSSQL_URL") @@ -46,7 +46,7 @@ func TestMSSQL_Initialize(t *testing.T) { } func TestMSSQL_CreateUser(t *testing.T) { - if os.Getenv("MSSQL_URL") == "" { + if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" { return } connURL := os.Getenv("MSSQL_URL") @@ -82,7 +82,7 @@ func TestMSSQL_CreateUser(t *testing.T) { } func TestMSSQL_RevokeUser(t *testing.T) { - if os.Getenv("MSSQL_URL") == "" { + if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" { return } connURL := os.Getenv("MSSQL_URL") From 09cdea92fd8b8fa47cccd811286ca60d3bf70d2a Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Tue, 18 Apr 2017 17:32:08 -0400 Subject: [PATCH 081/162] Adding explicit database to sp_msloginmappings call (#2611) --- plugins/database/mssql/mssql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 567a095b6..b0e0ab6d4 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -187,7 +187,7 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro // we need to drop the database users before we can drop the login and the role // This isn't done in a transaction because even if we fail along the way, // we want to remove as much access as possible - stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username)) if err != nil { return err } From fcbcc22bd94b1a5596a3aac3cb137e5969524694 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 19 Apr 2017 11:19:29 -0700 Subject: [PATCH 082/162] Fix cassandra deps breakage --- plugins/helper/database/connutil/cassandra.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 305bc6e3d..028c6814f 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -143,7 +143,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { } clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, + Config: tlsConfig, } } From 2ab159569d1ff3725036b42a2346a1e518029423 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 19 Apr 2017 15:46:07 -0700 Subject: [PATCH 083/162] Use the same TLS cert for the server and client --- helper/pluginutil/runner.go | 6 +- helper/pluginutil/tls.go | 116 +++++++----------------------------- helper/strutil/strutil.go | 13 ++++ 3 files changed, 39 insertions(+), 96 deletions(-) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index a57abad0e..bbc5ab99b 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -51,20 +51,20 @@ type PluginRunner struct { // plugin. func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate - CACertBytes, CACert, CAKey, err := GenerateCACert() + certBytes, key, err := GenerateCert() if err != nil { return nil, err } // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := CreateClientTLSConfig(CACert, CAKey) + clientTLSConfig, err := CreateClientTLSConfig(certBytes, key) if err != nil { return nil, err } // Use CA to sign a server cert and wrap the values in a response wrapped // token. - wrapToken, err := WrapServerConfig(wrapper, CACertBytes, CACert, CAKey) + wrapToken, err := WrapServerConfig(wrapper, certBytes, key) if err != nil { return nil, err } diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index c7aa42ee6..d4c0946e4 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -29,58 +29,19 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) -// GenerateCACert returns a CA cert used to later sign the certificates for the -// plugin client and server. -func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { - key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - if err != nil { - return nil, nil, nil, err - } - - host, err := uuid.GenerateUUID() - if err != nil { - return nil, nil, nil, err - } - host = "localhost" - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: host, - }, - DNSNames: []string{host}, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - // 30 years of single-active uptime ought to be enough for anybody - NotAfter: time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) - if err != nil { - return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) - } - - caCert, err := x509.ParseCertificate(certBytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) - } - - return certBytes, caCert, key, nil -} - // generateSignedCert is used internally to create certificates for the plugin // client and server. These certs are signed by the given CA Cert and Key. -func generateSignedCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { +func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, err + } + host, err := uuid.GenerateUUID() if err != nil { - return nil, nil, nil, err + return nil, nil, err } - host = "localhost" + template := &x509.Certificate{ Subject: pkix.Name{ CommonName: host, @@ -94,48 +55,38 @@ func generateSignedCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]by SerialNumber: big.NewInt(mathrand.Int63()), NotBefore: time.Now().Add(-30 * time.Second), NotAfter: time.Now().Add(262980 * time.Hour), + IsCA: true, } - clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) if err != nil { - return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) + return nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) } - certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) - if err != nil { - return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) - } - - clientCert, err := x509.ParseCertificate(certBytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) - } - - return certBytes, clientCert, clientKey, nil + return certBytes, key, nil } // CreateClientTLSConfig creates a signed certificate and returns a configured // TLS config. -func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (*tls.Config, error) { - clientCertBytes, clientCert, clientKey, err := generateSignedCert(CACert, CAKey) +func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config, error) { + clientCert, err := x509.ParseCertificate(certBytes) if err != nil { - return nil, err + return nil, fmt.Errorf("error parsing generated plugin certificate: %v", err) } cert := tls.Certificate{ - Certificate: [][]byte{clientCertBytes}, - PrivateKey: clientKey, + Certificate: [][]byte{certBytes}, + PrivateKey: key, Leaf: clientCert, } clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(CACert) + clientCertPool.AddCert(clientCert) tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, RootCAs: clientCertPool, - ClientCAs: clientCertPool, - ServerName: CACert.Subject.CommonName, + ServerName: clientCert.Subject.CommonName, MinVersion: tls.VersionTLS12, } @@ -146,19 +97,14 @@ func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (* // WrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys Wrapper, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { - serverCertBytes, _, serverKey, err := generateSignedCert(CACert, CAKey) - if err != nil { - return "", err - } - rawKey, err := x509.MarshalECPrivateKey(serverKey) +func WrapServerConfig(sys Wrapper, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { + rawKey, err := x509.MarshalECPrivateKey(key) if err != nil { return "", err } wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ - "CACert": CACertBytes, - "ServerCert": serverCertBytes, + "ServerCert": certBytes, "ServerKey": rawKey, }, time.Second*10, true) @@ -217,22 +163,6 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, errors.New("error during token unwrap request secret is nil") } - // Retrieve and parse the CA Certificate - CABytesRaw, ok := secret.Data["CACert"].(string) - if !ok { - return nil, errors.New("error unmarshalling CA certificate") - } - - CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - CACert, err := x509.ParseCertificate(CABytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - // Retrieve and parse the server's certificate serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) if !ok { @@ -267,7 +197,7 @@ func VaultPluginTLSProvider() (*tls.Config, error) { // Add CA cert to the cert pool caCertPool := x509.NewCertPool() - caCertPool.AddCert(CACert) + caCertPool.AddCert(serverCert) // Build a certificate object out of the server's cert and private key. cert := tls.Certificate{ diff --git a/helper/strutil/strutil.go b/helper/strutil/strutil.go index 7c7f64d3d..986928e0e 100644 --- a/helper/strutil/strutil.go +++ b/helper/strutil/strutil.go @@ -29,6 +29,19 @@ func StrListSubset(super, sub []string) bool { return true } +// Parses a comma separated list of strings into a slice of strings. +// The return slice will be sorted and will not contain duplicate or +// empty items. +func ParseDedupAndSortStrings(input string, sep string) []string { + input = strings.TrimSpace(input) + parsed := []string{} + if input == "" { + // Don't return nil + return parsed + } + return RemoveDuplicates(strings.Split(input, sep), false) +} + // Parses a comma separated list of strings into a slice of strings. // The return slice will be sorted and will not contain duplicate or // empty items. The values will be converted to lower case. From 6f9d1783707c0d64465ada8bff05dc7a05fc93b2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 20 Apr 2017 18:46:41 -0700 Subject: [PATCH 084/162] Calls to builtin plugins now go directly to the implementation instead of go-plugin --- builtin/logical/database/backend_test.go | 5 +- builtin/logical/database/dbplugin/plugin.go | 27 ++++++-- cli/commands.go | 6 -- command/plugin_exec.go | 66 ------------------- command/server.go | 61 +++++------------ helper/builtinplugins/builtin.go | 12 ++-- helper/pluginutil/runner.go | 11 ++-- plugins/database/mysql/mysql.go | 11 ++-- plugins/database/mysql/mysql_test.go | 9 ++- plugins/database/postgresql/postgresql.go | 12 ++-- .../database/postgresql/postgresql_test.go | 13 ++-- vault/core.go | 12 +--- vault/plugin_catalog.go | 28 ++++---- 13 files changed, 94 insertions(+), 179 deletions(-) delete mode 100644 command/plugin_exec.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 2615577fd..2ece767fc 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -11,10 +11,10 @@ import ( "testing" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/plugins/database/postgresql" "github.com/hashicorp/vault/vault" "github.com/lib/pq" "github.com/mitchellh/mapstructure" @@ -91,8 +91,7 @@ func TestBackend_PluginMain(t *testing.T) { return } - f, _ := builtinplugins.BuiltinPlugins.Get("postgresql-database-plugin") - f() + postgresql.Run() } func TestBackend_config_connection(t *testing.T) { diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 61de0fe8c..9a6691fba 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -33,15 +33,32 @@ type Statements struct { // object in a logging and metrics middleware. func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { // Look for plugin in the plugin catalog - pluginMeta, err := sys.LookupPlugin(pluginName) + pluginRunner, err := sys.LookupPlugin(pluginName) if err != nil { return nil, err } - // create a DatabasePluginClient instance - db, err := newPluginClient(sys, pluginMeta) - if err != nil { - return nil, err + var db DatabaseType + if pluginRunner.Builtin { + // Plugin is builtin so we can retrieve an instance of the interface + // from the pluginRunner. Then cast it to a DatabaseType. + dbRaw, err := pluginRunner.BuiltinFactory() + if err != nil { + return nil, fmt.Errorf("error getting plugin type: %s", err) + } + + var ok bool + db, ok = dbRaw.(DatabaseType) + if !ok { + return nil, fmt.Errorf("unsuported database type: %s", pluginName) + } + + } else { + // create a DatabasePluginClient instance + db, err = newPluginClient(sys, pluginRunner) + if err != nil { + return nil, err + } } typeStr, err := db.Type() diff --git a/cli/commands.go b/cli/commands.go index e7545ca90..13f7c8b25 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -331,11 +331,5 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { Ui: metaPtr.Ui, }, nil }, - - "plugin-exec": func() (cli.Command, error) { - return &command.PluginExec{ - Meta: *metaPtr, - }, nil - }, } } diff --git a/command/plugin_exec.go b/command/plugin_exec.go deleted file mode 100644 index 575be14b7..000000000 --- a/command/plugin_exec.go +++ /dev/null @@ -1,66 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/builtinplugins" - "github.com/hashicorp/vault/meta" -) - -type PluginExec struct { - meta.Meta -} - -func (c *PluginExec) Run(args []string) int { - flags := c.Meta.FlagSet("plugin-exec", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nplugin-exec expects one argument: the plugin to execute.")) - return 1 - } - - pluginName := args[0] - - runner, ok := builtinplugins.BuiltinPlugins.Get(pluginName) - if !ok { - c.Ui.Error(fmt.Sprintf( - "No plugin with the name %s found", pluginName)) - return 1 - } - - err := runner() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error running plugin: %s", err)) - return 1 - } - - return 0 -} - -func (c *PluginExec) Synopsis() string { - return "Runs a builtin plugin. Should only be called by vault." -} - -func (c *PluginExec) Help() string { - helpText := ` -Usage: vault plugin-exec type - - Runs a builtin plugin. Should only be called by vault. - - This will execute a plugin for use in a plugable location in vault. If run by - a cli user it will print a message indicating it can not be executed by anyone - other than vault. For supported plugin types see the vault documentation. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/server.go b/command/server.go index 6548aa58c..ef9e3e3a0 100644 --- a/command/server.go +++ b/command/server.go @@ -1,10 +1,8 @@ package command import ( - "crypto/sha256" "encoding/base64" "fmt" - "io" "net" "net/http" "net/url" @@ -133,33 +131,6 @@ func (c *ServerCommand) Run(args []string) int { dev = true } - // Record the vault binary's location and SHA-256 checksum for use in - // builtin plugins. - ex, err := os.Executable() - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error looking up vault binary: %s", err)) - return 1 - } - - file, err := os.Open(ex) - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error loading vault binary: %s", err)) - return 1 - } - defer file.Close() - - hash := sha256.New() - _, err = io.Copy(hash, file) - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error checksumming vault binary: %s", err)) - return 1 - } - - sha256Value := hash.Sum(nil) - // Validation if !dev { switch { @@ -254,23 +225,21 @@ func (c *ServerCommand) Run(args []string) int { } coreConfig := &vault.CoreConfig{ - Physical: backend, - RedirectAddr: config.Storage.RedirectAddr, - HAPhysical: nil, - Seal: seal, - AuditBackends: c.AuditBackends, - CredentialBackends: c.CredentialBackends, - LogicalBackends: c.LogicalBackends, - Logger: c.logger, - DisableCache: config.DisableCache, - DisableMlock: config.DisableMlock, - MaxLeaseTTL: config.MaxLeaseTTL, - DefaultLeaseTTL: config.DefaultLeaseTTL, - ClusterName: config.ClusterName, - CacheSize: config.CacheSize, - PluginDirectory: config.PluginDirectory, - VaultBinaryLocation: ex, - VaultBinarySHA256: sha256Value, + Physical: backend, + RedirectAddr: config.Storage.RedirectAddr, + HAPhysical: nil, + Seal: seal, + AuditBackends: c.AuditBackends, + CredentialBackends: c.CredentialBackends, + LogicalBackends: c.LogicalBackends, + Logger: c.logger, + DisableCache: config.DisableCache, + DisableMlock: config.DisableMlock, + MaxLeaseTTL: config.MaxLeaseTTL, + DefaultLeaseTTL: config.DefaultLeaseTTL, + ClusterName: config.ClusterName, + CacheSize: config.CacheSize, + PluginDirectory: config.PluginDirectory, } if dev { coreConfig.DevToken = devRootTokenID diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index beedbb15b..9c51ae478 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -5,20 +5,22 @@ import ( "github.com/hashicorp/vault/plugins/database/postgresql" ) +type BuiltinFactory func() (interface{}, error) + var BuiltinPlugins *builtinPlugins = &builtinPlugins{ - plugins: map[string]func() error{ - "mysql-database-plugin": mysql.Run, - "postgresql-database-plugin": postgresql.Run, + plugins: map[string]BuiltinFactory{ + "mysql-database-plugin": mysql.New, + "postgresql-database-plugin": postgresql.New, }, } // The list of builtin plugins should not be changed by any other package, so we // store them in an unexported variable in this unexported struct. type builtinPlugins struct { - plugins map[string]func() error + plugins map[string]BuiltinFactory } -func (b *builtinPlugins) Get(name string) (func() error, bool) { +func (b *builtinPlugins) Get(name string) (BuiltinFactory, bool) { f, ok := b.plugins[name] return f, ok } diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index bbc5ab99b..95de96a5a 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -40,11 +40,12 @@ type LookWrapper interface { // PluginRunner defines the metadata needed to run a plugin securely with // go-plugin. type PluginRunner struct { - Name string `json:"name"` - Command string `json:"command"` - Args []string `json:"args"` - Sha256 []byte `json:"sha256"` - Builtin bool `json:"builtin"` + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args"` + Sha256 []byte `json:"sha256"` + Builtin bool `json:"builtin"` + BuiltinFactory func() (interface{}, error) `json:"-"` } // Run takes a wrapper instance, and the go-plugin paramaters and executes a diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index ea14a6782..e7e2a8aea 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -23,7 +23,7 @@ type MySQL struct { credsutil.CredentialsProducer } -func New() *MySQL { +func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = mySQLTypeName @@ -37,14 +37,17 @@ func New() *MySQL { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*MySQL)) return nil } diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index 2b1f27291..c86f9c2f6 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -66,7 +66,8 @@ func TestMySQL_Initialize(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MySQL) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) err := db.Initialize(connectionDetails, true) @@ -92,7 +93,8 @@ func TestMySQL_CreateUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) if err != nil { @@ -127,7 +129,8 @@ func TestMySQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) if err != nil { diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index b8449f549..5781b6c3d 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -16,7 +16,8 @@ import ( const postgreSQLTypeName string = "postgres" -func New() *PostgreSQL { +// New implements builtinplugins.BuiltinFactory +func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = postgreSQLTypeName @@ -30,14 +31,17 @@ func New() *PostgreSQL { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instatiates a PostgreSQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*PostgreSQL)) return nil } diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index c7ccc8ee8..79391dc56 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -66,7 +66,9 @@ func TestPostgreSQL_Initialize(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) err := db.Initialize(connectionDetails, true) @@ -92,7 +94,8 @@ func TestPostgreSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -136,7 +139,8 @@ func TestPostgreSQL_RenewUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -176,7 +180,8 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) diff --git a/vault/core.go b/vault/core.go index ef99741bf..01ab49f75 100644 --- a/vault/core.go +++ b/vault/core.go @@ -335,12 +335,6 @@ type Core struct { // pluginDirectory is the location vault will look for plugin binaries pluginDirectory string - // vaultBinaryLocation is used to run builtin plugins in secure mode - vaultBinaryLocation string - - // vaultBinarySHA256 is used to run builtin plugins in secure mode - vaultBinarySHA256 []byte - // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog @@ -389,9 +383,7 @@ type CoreConfig struct { EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"` - PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` - VaultBinaryLocation string `json:"vault_binary_location" structs:"vault_binary_location" mapstructure:"vault_binary_location"` - VaultBinarySHA256 []byte `json:"vault_binary_sha256" structs:"vault_binary_sha256" mapstructure:"vault_binary_sha256"` + PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` ReloadFuncs *map[string][]ReloadFunc ReloadFuncsLock *sync.RWMutex @@ -449,8 +441,6 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterName: conf.ClusterName, clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), - vaultBinaryLocation: conf.VaultBinaryLocation, - vaultBinarySHA256: conf.VaultBinarySHA256, disableMlock: conf.DisableMlock, } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index b89224780..598a16fac 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -23,20 +23,16 @@ var ( // to be registered to the catalog before they can be used in backends. Builtin // plugins are automatically detected and included in the catalog. type PluginCatalog struct { - catalogView *BarrierView - directory string - vaultCommand string - vaultSHA256 []byte + catalogView *BarrierView + directory string lock sync.RWMutex } func (c *Core) setupPluginCatalog() error { c.pluginCatalog = &PluginCatalog{ - catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), - directory: c.pluginDirectory, - vaultCommand: c.vaultBinaryLocation, - vaultSHA256: c.vaultBinarySHA256, + catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), + directory: c.pluginDirectory, } return nil @@ -64,17 +60,15 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := builtinplugins.BuiltinPlugins.Get(name); !ok { - return nil, fmt.Errorf("no plugin found with name: %s", name) + if factory, ok := builtinplugins.BuiltinPlugins.Get(name); ok { + return &pluginutil.PluginRunner{ + Name: name, + Builtin: true, + BuiltinFactory: factory, + }, nil } - return &pluginutil.PluginRunner{ - Name: name, - Command: c.vaultCommand, - Args: []string{"plugin-exec", name}, - Sha256: c.vaultSHA256, - Builtin: true, - }, nil + return nil, fmt.Errorf("no plugin found with name: %s", name) } // Set registers a new external plugin with the catalog, or updates an existing From 30b06b593cd18ed4af0430af55fbf011f2860575 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 21 Apr 2017 09:10:26 -0700 Subject: [PATCH 085/162] Fix tests --- plugins/database/mysql/mysql.go | 1 + vault/logical_system_test.go | 2 -- vault/plugin_catalog_test.go | 4 ---- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index e7e2a8aea..6485aaa86 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -23,6 +23,7 @@ type MySQL struct { credsutil.CredentialsProducer } +// New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = mySQLTypeName diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 0785e07a1..9da4cbdec 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1122,8 +1122,6 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("error: %v", err) } c.pluginCatalog.directory = sym - c.pluginCatalog.vaultCommand = "vault" - c.pluginCatalog.vaultSHA256 = []byte{'1'} req := logical.TestRequest(t, logical.ListOperation, "plugin-catalog/") resp, err := b.HandleRequest(req) diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index e78e7d963..c33a890cd 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -21,8 +21,6 @@ func TestPluginCatalog_CRUD(t *testing.T) { t.Fatalf("error: %v", err) } core.pluginCatalog.directory = sym - core.pluginCatalog.vaultCommand = "vault" - core.pluginCatalog.vaultSHA256 = []byte{'1'} // Get builtin plugin p, err := core.pluginCatalog.Get("mysql-database-plugin") @@ -99,8 +97,6 @@ func TestPluginCatalog_List(t *testing.T) { t.Fatalf("error: %v", err) } core.pluginCatalog.directory = sym - core.pluginCatalog.vaultCommand = "vault" - core.pluginCatalog.vaultSHA256 = []byte{'1'} // Get builtin plugins and sort them builtinKeys := builtinplugins.BuiltinPlugins.Keys() From 4d0aac963dab1104e91f5b1bf58b630293fa138f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 21 Apr 2017 10:24:34 -0700 Subject: [PATCH 086/162] Fix tests --- vault/logical_system_test.go | 13 +++++++++---- vault/plugin_catalog_test.go | 20 +++++++++++++++++--- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 9da4cbdec..e9836946c 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1141,13 +1141,18 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { expectedBuiltin := &pluginutil.PluginRunner{ Name: "mysql-database-plugin", - Command: "vault", - Args: []string{"plugin-exec", "mysql-database-plugin"}, - Sha256: []byte{'1'}, Builtin: true, } + expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") - if !reflect.DeepEqual(resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) { + p := resp.Data["plugin"].(*pluginutil.PluginRunner) + if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { + t.Fatal("expected BuiltinFactory did not match actual") + } + + expectedBuiltin.BuiltinFactory = nil + p.BuiltinFactory = nil + if !reflect.DeepEqual(p, expectedBuiltin) { t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) } diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index c33a890cd..57e864892 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -30,12 +30,15 @@ func TestPluginCatalog_CRUD(t *testing.T) { expectedBuiltin := &pluginutil.PluginRunner{ Name: "mysql-database-plugin", - Command: "vault", - Args: []string{"plugin-exec", "mysql-database-plugin"}, - Sha256: []byte{'1'}, Builtin: true, } + expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { + t.Fatal("expected BuiltinFactory did not match actual") + } + expectedBuiltin.BuiltinFactory = nil + p.BuiltinFactory = nil if !reflect.DeepEqual(p, expectedBuiltin) { t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) } @@ -83,6 +86,17 @@ func TestPluginCatalog_CRUD(t *testing.T) { t.Fatalf("unexpected error %v", err) } + expectedBuiltin = &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Builtin: true, + } + expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + + if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { + t.Fatal("expected BuiltinFactory did not match actual") + } + expectedBuiltin.BuiltinFactory = nil + p.BuiltinFactory = nil if !reflect.DeepEqual(p, expectedBuiltin) { t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) } From c005f8fc91dda2c2744dc8c44423e9102d20b461 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Sun, 23 Apr 2017 09:02:57 +0800 Subject: [PATCH 087/162] Add cassandra plugin --- .../cassandra-database-plugin/main.go | 16 + plugins/database/cassandra/cassandra.go | 145 +++ plugins/database/cassandra/cassandra_test.go | 226 ++++ .../cassandra/test-fixtures/cassandra.yaml | 1146 +++++++++++++++++ plugins/helper/database/connutil/cassandra.go | 7 +- 5 files changed, 1537 insertions(+), 3 deletions(-) create mode 100644 plugins/database/cassandra/cassandra-database-plugin/main.go create mode 100644 plugins/database/cassandra/cassandra.go create mode 100644 plugins/database/cassandra/cassandra_test.go create mode 100644 plugins/database/cassandra/test-fixtures/cassandra.yaml diff --git a/plugins/database/cassandra/cassandra-database-plugin/main.go b/plugins/database/cassandra/cassandra-database-plugin/main.go new file mode 100644 index 000000000..79f0e0dbe --- /dev/null +++ b/plugins/database/cassandra/cassandra-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/cassandra" +) + +func main() { + err := cassandra.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go new file mode 100644 index 000000000..621d6e375 --- /dev/null +++ b/plugins/database/cassandra/cassandra.go @@ -0,0 +1,145 @@ +package cassandra + +import ( + "fmt" + "strings" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const ( + defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultRollbackCQL = `DROP USER '{{username}}';` + cassandraTypeName = "cassandra" +) + +type Cassandra struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *Cassandra { + connProducer := &connutil.CassandraConnectionProducer{} + connProducer.Type = cassandraTypeName + + credsProducer := &credsutil.CassandraCredentialsProducer{} + + dbType := &Cassandra{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MySQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +func (c *Cassandra) Type() (string, error) { + return cassandraTypeName, nil +} + +func (c *Cassandra) getConnection() (*gocql.Session, error) { + session, err := c.Connection() + if err != nil { + return nil, err + } + + return session.(*gocql.Session), nil +} + +// func (c *Cassandra) CreateUser(statements dbplugin.Statements, username, password, expiration string) error { +func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + c.Lock() + defer c.Unlock() + + // Get the connection + session, err := c.getConnection() + if err != nil { + return "", "", err + } + + creationCQL := statements.CreationStatements + if creationCQL == "" { + creationCQL = defaultCreationCQL + } + rollbackCQL := statements.RollbackStatements + if rollbackCQL == "" { + rollbackCQL = defaultRollbackCQL + } + + username, err = c.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = c.GeneratePassword() + if err != nil { + return "", "", err + } + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err = session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + if err != nil { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + } + return "", "", err + } + } + + return username, password, nil +} + +func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // NOOP + return nil +} + +func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the lock + c.Lock() + defer c.Unlock() + + session, err := c.getConnection() + if err != nil { + return err + } + + err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() + if err != nil { + return fmt.Errorf("error removing user '%s': %s", username, err) + } + + return nil +} diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go new file mode 100644 index 000000000..b81c32710 --- /dev/null +++ b/plugins/database/cassandra/cassandra_test.go @@ -0,0 +1,226 @@ +package cassandra + +import ( + "os" + "strconv" + "testing" + "time" + + "fmt" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +func prepareCassandraTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("CASSANDRA_HOST") != "" { + return func() {}, os.Getenv("CASSANDRA_HOST") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + cwd, _ := os.Getwd() + cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd) + + ro := &dockertest.RunOptions{ + Repository: "cassandra", + Tag: "latest", + Mounts: []string{cassandraMountPath}, + } + resource, err := pool.RunWithOptions(ro) + if err != nil { + t.Fatalf("Could not start local cassandra docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("localhost:%s", resource.GetPort("9042/tcp")) + port, _ := strconv.Atoi(resource.GetPort("9042/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + clusterConfig := gocql.NewCluster(retURL) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: "cassandra", + Password: "cassandra", + } + clusterConfig.ProtoVersion = 4 + clusterConfig.Port = port + + session, err := clusterConfig.CreateSession() + if err != nil { + return fmt.Errorf("error creating session: %s", err) + } + defer session.Close() + return nil + }); err != nil { + t.Fatalf("Could not connect to cassandra docker container: %s", err) + } + return +} + +func TestCassandra_Initialize(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestCassandra_CreateUser(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testCassandraRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMyCassandra_RenewUser(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testCassandraRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestCassandra_RevokeUser(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testCassandraRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + clusterConfig := gocql.NewCluster(connURL) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: username, + Password: password, + } + clusterConfig.ProtoVersion = 4 + + session, err := clusterConfig.CreateSession() + if err != nil { + return fmt.Errorf("error creating session: %s", err) + } + defer session.Close() + return nil +} + +const testCassandraRole = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER; +GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};` diff --git a/plugins/database/cassandra/test-fixtures/cassandra.yaml b/plugins/database/cassandra/test-fixtures/cassandra.yaml new file mode 100644 index 000000000..5b12c8cf4 --- /dev/null +++ b/plugins/database/cassandra/test-fixtures/cassandra.yaml @@ -0,0 +1,1146 @@ +# Cassandra storage config YAML + +# NOTE: +# See http://wiki.apache.org/cassandra/StorageConfiguration for +# full explanations of configuration directives +# /NOTE + +# The name of the cluster. This is mainly used to prevent machines in +# one logical cluster from joining another. +cluster_name: 'Test Cluster' + +# This defines the number of tokens randomly assigned to this node on the ring +# The more tokens, relative to other nodes, the larger the proportion of data +# that this node will store. You probably want all nodes to have the same number +# of tokens assuming they have equal hardware capability. +# +# If you leave this unspecified, Cassandra will use the default of 1 token for legacy compatibility, +# and will use the initial_token as described below. +# +# Specifying initial_token will override this setting on the node's initial start, +# on subsequent starts, this setting will apply even if initial token is set. +# +# If you already have a cluster with 1 token per node, and wish to migrate to +# multiple tokens per node, see http://wiki.apache.org/cassandra/Operations +num_tokens: 256 + +# Triggers automatic allocation of num_tokens tokens for this node. The allocation +# algorithm attempts to choose tokens in a way that optimizes replicated load over +# the nodes in the datacenter for the replication strategy used by the specified +# keyspace. +# +# The load assigned to each node will be close to proportional to its number of +# vnodes. +# +# Only supported with the Murmur3Partitioner. +# allocate_tokens_for_keyspace: KEYSPACE + +# initial_token allows you to specify tokens manually. While you can use it with +# vnodes (num_tokens > 1, above) -- in which case you should provide a +# comma-separated list -- it's primarily used when adding nodes to legacy clusters +# that do not have vnodes enabled. +# initial_token: + +# See http://wiki.apache.org/cassandra/HintedHandoff +# May either be "true" or "false" to enable globally +hinted_handoff_enabled: true + +# When hinted_handoff_enabled is true, a black list of data centers that will not +# perform hinted handoff +# hinted_handoff_disabled_datacenters: +# - DC1 +# - DC2 + +# this defines the maximum amount of time a dead host will have hints +# generated. After it has been dead this long, new hints for it will not be +# created until it has been seen alive and gone down again. +max_hint_window_in_ms: 10800000 # 3 hours + +# Maximum throttle in KBs per second, per delivery thread. This will be +# reduced proportionally to the number of nodes in the cluster. (If there +# are two nodes in the cluster, each delivery thread will use the maximum +# rate; if there are three, each will throttle to half of the maximum, +# since we expect two nodes to be delivering hints simultaneously.) +hinted_handoff_throttle_in_kb: 1024 + +# Number of threads with which to deliver hints; +# Consider increasing this number when you have multi-dc deployments, since +# cross-dc handoff tends to be slower +max_hints_delivery_threads: 2 + +# Directory where Cassandra should store hints. +# If not set, the default directory is $CASSANDRA_HOME/data/hints. +# hints_directory: /var/lib/cassandra/hints + +# How often hints should be flushed from the internal buffers to disk. +# Will *not* trigger fsync. +hints_flush_period_in_ms: 10000 + +# Maximum size for a single hints file, in megabytes. +max_hints_file_size_in_mb: 128 + +# Compression to apply to the hint files. If omitted, hints files +# will be written uncompressed. LZ4, Snappy, and Deflate compressors +# are supported. +#hints_compression: +# - class_name: LZ4Compressor +# parameters: +# - + +# Maximum throttle in KBs per second, total. This will be +# reduced proportionally to the number of nodes in the cluster. +batchlog_replay_throttle_in_kb: 1024 + +# Authentication backend, implementing IAuthenticator; used to identify users +# Out of the box, Cassandra provides org.apache.cassandra.auth.{AllowAllAuthenticator, +# PasswordAuthenticator}. +# +# - AllowAllAuthenticator performs no checks - set it to disable authentication. +# - PasswordAuthenticator relies on username/password pairs to authenticate +# users. It keeps usernames and hashed passwords in system_auth.credentials table. +# Please increase system_auth keyspace replication factor if you use this authenticator. +# If using PasswordAuthenticator, CassandraRoleManager must also be used (see below) +authenticator: PasswordAuthenticator + +# Authorization backend, implementing IAuthorizer; used to limit access/provide permissions +# Out of the box, Cassandra provides org.apache.cassandra.auth.{AllowAllAuthorizer, +# CassandraAuthorizer}. +# +# - AllowAllAuthorizer allows any action to any user - set it to disable authorization. +# - CassandraAuthorizer stores permissions in system_auth.permissions table. Please +# increase system_auth keyspace replication factor if you use this authorizer. +authorizer: CassandraAuthorizer + +# Part of the Authentication & Authorization backend, implementing IRoleManager; used +# to maintain grants and memberships between roles. +# Out of the box, Cassandra provides org.apache.cassandra.auth.CassandraRoleManager, +# which stores role information in the system_auth keyspace. Most functions of the +# IRoleManager require an authenticated login, so unless the configured IAuthenticator +# actually implements authentication, most of this functionality will be unavailable. +# +# - CassandraRoleManager stores role data in the system_auth keyspace. Please +# increase system_auth keyspace replication factor if you use this role manager. +role_manager: CassandraRoleManager + +# Validity period for roles cache (fetching granted roles can be an expensive +# operation depending on the role manager, CassandraRoleManager is one example) +# Granted roles are cached for authenticated sessions in AuthenticatedUser and +# after the period specified here, become eligible for (async) reload. +# Defaults to 2000, set to 0 to disable caching entirely. +# Will be disabled automatically for AllowAllAuthenticator. +roles_validity_in_ms: 2000 + +# Refresh interval for roles cache (if enabled). +# After this interval, cache entries become eligible for refresh. Upon next +# access, an async reload is scheduled and the old value returned until it +# completes. If roles_validity_in_ms is non-zero, then this must be +# also. +# Defaults to the same value as roles_validity_in_ms. +# roles_update_interval_in_ms: 2000 + +# Validity period for permissions cache (fetching permissions can be an +# expensive operation depending on the authorizer, CassandraAuthorizer is +# one example). Defaults to 2000, set to 0 to disable. +# Will be disabled automatically for AllowAllAuthorizer. +permissions_validity_in_ms: 2000 + +# Refresh interval for permissions cache (if enabled). +# After this interval, cache entries become eligible for refresh. Upon next +# access, an async reload is scheduled and the old value returned until it +# completes. If permissions_validity_in_ms is non-zero, then this must be +# also. +# Defaults to the same value as permissions_validity_in_ms. +# permissions_update_interval_in_ms: 2000 + +# Validity period for credentials cache. This cache is tightly coupled to +# the provided PasswordAuthenticator implementation of IAuthenticator. If +# another IAuthenticator implementation is configured, this cache will not +# be automatically used and so the following settings will have no effect. +# Please note, credentials are cached in their encrypted form, so while +# activating this cache may reduce the number of queries made to the +# underlying table, it may not bring a significant reduction in the +# latency of individual authentication attempts. +# Defaults to 2000, set to 0 to disable credentials caching. +credentials_validity_in_ms: 2000 + +# Refresh interval for credentials cache (if enabled). +# After this interval, cache entries become eligible for refresh. Upon next +# access, an async reload is scheduled and the old value returned until it +# completes. If credentials_validity_in_ms is non-zero, then this must be +# also. +# Defaults to the same value as credentials_validity_in_ms. +# credentials_update_interval_in_ms: 2000 + +# The partitioner is responsible for distributing groups of rows (by +# partition key) across nodes in the cluster. You should leave this +# alone for new clusters. The partitioner can NOT be changed without +# reloading all data, so when upgrading you should set this to the +# same partitioner you were already using. +# +# Besides Murmur3Partitioner, partitioners included for backwards +# compatibility include RandomPartitioner, ByteOrderedPartitioner, and +# OrderPreservingPartitioner. +# +partitioner: org.apache.cassandra.dht.Murmur3Partitioner + +# Directories where Cassandra should store data on disk. Cassandra +# will spread data evenly across them, subject to the granularity of +# the configured compaction strategy. +# If not set, the default directory is $CASSANDRA_HOME/data/data. +data_file_directories: + - /var/lib/cassandra/data + +# commit log. when running on magnetic HDD, this should be a +# separate spindle than the data directories. +# If not set, the default directory is $CASSANDRA_HOME/data/commitlog. +commitlog_directory: /var/lib/cassandra/commitlog + +# Enable / disable CDC functionality on a per-node basis. This modifies the logic used +# for write path allocation rejection (standard: never reject. cdc: reject Mutation +# containing a CDC-enabled table if at space limit in cdc_raw_directory). +cdc_enabled: false + +# CommitLogSegments are moved to this directory on flush if cdc_enabled: true and the +# segment contains mutations for a CDC-enabled table. This should be placed on a +# separate spindle than the data directories. If not set, the default directory is +# $CASSANDRA_HOME/data/cdc_raw. +# cdc_raw_directory: /var/lib/cassandra/cdc_raw + +# Policy for data disk failures: +# +# die +# shut down gossip and client transports and kill the JVM for any fs errors or +# single-sstable errors, so the node can be replaced. +# +# stop_paranoid +# shut down gossip and client transports even for single-sstable errors, +# kill the JVM for errors during startup. +# +# stop +# shut down gossip and client transports, leaving the node effectively dead, but +# can still be inspected via JMX, kill the JVM for errors during startup. +# +# best_effort +# stop using the failed disk and respond to requests based on +# remaining available sstables. This means you WILL see obsolete +# data at CL.ONE! +# +# ignore +# ignore fatal errors and let requests fail, as in pre-1.2 Cassandra +disk_failure_policy: stop + +# Policy for commit disk failures: +# +# die +# shut down gossip and Thrift and kill the JVM, so the node can be replaced. +# +# stop +# shut down gossip and Thrift, leaving the node effectively dead, but +# can still be inspected via JMX. +# +# stop_commit +# shutdown the commit log, letting writes collect but +# continuing to service reads, as in pre-2.0.5 Cassandra +# +# ignore +# ignore fatal errors and let the batches fail +commit_failure_policy: stop + +# Maximum size of the native protocol prepared statement cache +# +# Valid values are either "auto" (omitting the value) or a value greater 0. +# +# Note that specifying a too large value will result in long running GCs and possbily +# out-of-memory errors. Keep the value at a small fraction of the heap. +# +# If you constantly see "prepared statements discarded in the last minute because +# cache limit reached" messages, the first step is to investigate the root cause +# of these messages and check whether prepared statements are used correctly - +# i.e. use bind markers for variable parts. +# +# Do only change the default value, if you really have more prepared statements than +# fit in the cache. In most cases it is not neccessary to change this value. +# Constantly re-preparing statements is a performance penalty. +# +# Default value ("auto") is 1/256th of the heap or 10MB, whichever is greater +prepared_statements_cache_size_mb: + +# Maximum size of the Thrift prepared statement cache +# +# If you do not use Thrift at all, it is safe to leave this value at "auto". +# +# See description of 'prepared_statements_cache_size_mb' above for more information. +# +# Default value ("auto") is 1/256th of the heap or 10MB, whichever is greater +thrift_prepared_statements_cache_size_mb: + +# Maximum size of the key cache in memory. +# +# Each key cache hit saves 1 seek and each row cache hit saves 2 seeks at the +# minimum, sometimes more. The key cache is fairly tiny for the amount of +# time it saves, so it's worthwhile to use it at large numbers. +# The row cache saves even more time, but must contain the entire row, +# so it is extremely space-intensive. It's best to only use the +# row cache if you have hot rows or static rows. +# +# NOTE: if you reduce the size, you may not get you hottest keys loaded on startup. +# +# Default value is empty to make it "auto" (min(5% of Heap (in MB), 100MB)). Set to 0 to disable key cache. +key_cache_size_in_mb: + +# Duration in seconds after which Cassandra should +# save the key cache. Caches are saved to saved_caches_directory as +# specified in this configuration file. +# +# Saved caches greatly improve cold-start speeds, and is relatively cheap in +# terms of I/O for the key cache. Row cache saving is much more expensive and +# has limited use. +# +# Default is 14400 or 4 hours. +key_cache_save_period: 14400 + +# Number of keys from the key cache to save +# Disabled by default, meaning all keys are going to be saved +# key_cache_keys_to_save: 100 + +# Row cache implementation class name. Available implementations: +# +# org.apache.cassandra.cache.OHCProvider +# Fully off-heap row cache implementation (default). +# +# org.apache.cassandra.cache.SerializingCacheProvider +# This is the row cache implementation availabile +# in previous releases of Cassandra. +# row_cache_class_name: org.apache.cassandra.cache.OHCProvider + +# Maximum size of the row cache in memory. +# Please note that OHC cache implementation requires some additional off-heap memory to manage +# the map structures and some in-flight memory during operations before/after cache entries can be +# accounted against the cache capacity. This overhead is usually small compared to the whole capacity. +# Do not specify more memory that the system can afford in the worst usual situation and leave some +# headroom for OS block level cache. Do never allow your system to swap. +# +# Default value is 0, to disable row caching. +row_cache_size_in_mb: 0 + +# Duration in seconds after which Cassandra should save the row cache. +# Caches are saved to saved_caches_directory as specified in this configuration file. +# +# Saved caches greatly improve cold-start speeds, and is relatively cheap in +# terms of I/O for the key cache. Row cache saving is much more expensive and +# has limited use. +# +# Default is 0 to disable saving the row cache. +row_cache_save_period: 0 + +# Number of keys from the row cache to save. +# Specify 0 (which is the default), meaning all keys are going to be saved +# row_cache_keys_to_save: 100 + +# Maximum size of the counter cache in memory. +# +# Counter cache helps to reduce counter locks' contention for hot counter cells. +# In case of RF = 1 a counter cache hit will cause Cassandra to skip the read before +# write entirely. With RF > 1 a counter cache hit will still help to reduce the duration +# of the lock hold, helping with hot counter cell updates, but will not allow skipping +# the read entirely. Only the local (clock, count) tuple of a counter cell is kept +# in memory, not the whole counter, so it's relatively cheap. +# +# NOTE: if you reduce the size, you may not get you hottest keys loaded on startup. +# +# Default value is empty to make it "auto" (min(2.5% of Heap (in MB), 50MB)). Set to 0 to disable counter cache. +# NOTE: if you perform counter deletes and rely on low gcgs, you should disable the counter cache. +counter_cache_size_in_mb: + +# Duration in seconds after which Cassandra should +# save the counter cache (keys only). Caches are saved to saved_caches_directory as +# specified in this configuration file. +# +# Default is 7200 or 2 hours. +counter_cache_save_period: 7200 + +# Number of keys from the counter cache to save +# Disabled by default, meaning all keys are going to be saved +# counter_cache_keys_to_save: 100 + +# saved caches +# If not set, the default directory is $CASSANDRA_HOME/data/saved_caches. +saved_caches_directory: /var/lib/cassandra/saved_caches + +# commitlog_sync may be either "periodic" or "batch." +# +# When in batch mode, Cassandra won't ack writes until the commit log +# has been fsynced to disk. It will wait +# commitlog_sync_batch_window_in_ms milliseconds between fsyncs. +# This window should be kept short because the writer threads will +# be unable to do extra work while waiting. (You may need to increase +# concurrent_writes for the same reason.) +# +# commitlog_sync: batch +# commitlog_sync_batch_window_in_ms: 2 +# +# the other option is "periodic" where writes may be acked immediately +# and the CommitLog is simply synced every commitlog_sync_period_in_ms +# milliseconds. +commitlog_sync: periodic +commitlog_sync_period_in_ms: 10000 + +# The size of the individual commitlog file segments. A commitlog +# segment may be archived, deleted, or recycled once all the data +# in it (potentially from each columnfamily in the system) has been +# flushed to sstables. +# +# The default size is 32, which is almost always fine, but if you are +# archiving commitlog segments (see commitlog_archiving.properties), +# then you probably want a finer granularity of archiving; 8 or 16 MB +# is reasonable. +# Max mutation size is also configurable via max_mutation_size_in_kb setting in +# cassandra.yaml. The default is half the size commitlog_segment_size_in_mb * 1024. +# +# NOTE: If max_mutation_size_in_kb is set explicitly then commitlog_segment_size_in_mb must +# be set to at least twice the size of max_mutation_size_in_kb / 1024 +# +commitlog_segment_size_in_mb: 32 + +# Compression to apply to the commit log. If omitted, the commit log +# will be written uncompressed. LZ4, Snappy, and Deflate compressors +# are supported. +# commitlog_compression: +# - class_name: LZ4Compressor +# parameters: +# - + +# any class that implements the SeedProvider interface and has a +# constructor that takes a Map of parameters will do. +seed_provider: + # Addresses of hosts that are deemed contact points. + # Cassandra nodes use this list of hosts to find each other and learn + # the topology of the ring. You must change this if you are running + # multiple nodes! + - class_name: org.apache.cassandra.locator.SimpleSeedProvider + parameters: + # seeds is actually a comma-delimited list of addresses. + # Ex: ",," + - seeds: "172.17.0.3" + +# For workloads with more data than can fit in memory, Cassandra's +# bottleneck will be reads that need to fetch data from +# disk. "concurrent_reads" should be set to (16 * number_of_drives) in +# order to allow the operations to enqueue low enough in the stack +# that the OS and drives can reorder them. Same applies to +# "concurrent_counter_writes", since counter writes read the current +# values before incrementing and writing them back. +# +# On the other hand, since writes are almost never IO bound, the ideal +# number of "concurrent_writes" is dependent on the number of cores in +# your system; (8 * number_of_cores) is a good rule of thumb. +concurrent_reads: 32 +concurrent_writes: 32 +concurrent_counter_writes: 32 + +# For materialized view writes, as there is a read involved, so this should +# be limited by the less of concurrent reads or concurrent writes. +concurrent_materialized_view_writes: 32 + +# Maximum memory to use for sstable chunk cache and buffer pooling. +# 32MB of this are reserved for pooling buffers, the rest is used as an +# cache that holds uncompressed sstable chunks. +# Defaults to the smaller of 1/4 of heap or 512MB. This pool is allocated off-heap, +# so is in addition to the memory allocated for heap. The cache also has on-heap +# overhead which is roughly 128 bytes per chunk (i.e. 0.2% of the reserved size +# if the default 64k chunk size is used). +# Memory is only allocated when needed. +# file_cache_size_in_mb: 512 + +# Flag indicating whether to allocate on or off heap when the sstable buffer +# pool is exhausted, that is when it has exceeded the maximum memory +# file_cache_size_in_mb, beyond which it will not cache buffers but allocate on request. + +# buffer_pool_use_heap_if_exhausted: true + +# The strategy for optimizing disk read +# Possible values are: +# ssd (for solid state disks, the default) +# spinning (for spinning disks) +# disk_optimization_strategy: ssd + +# Total permitted memory to use for memtables. Cassandra will stop +# accepting writes when the limit is exceeded until a flush completes, +# and will trigger a flush based on memtable_cleanup_threshold +# If omitted, Cassandra will set both to 1/4 the size of the heap. +# memtable_heap_space_in_mb: 2048 +# memtable_offheap_space_in_mb: 2048 + +# Ratio of occupied non-flushing memtable size to total permitted size +# that will trigger a flush of the largest memtable. Larger mct will +# mean larger flushes and hence less compaction, but also less concurrent +# flush activity which can make it difficult to keep your disks fed +# under heavy write load. +# +# memtable_cleanup_threshold defaults to 1 / (memtable_flush_writers + 1) +# memtable_cleanup_threshold: 0.11 + +# Specify the way Cassandra allocates and manages memtable memory. +# Options are: +# +# heap_buffers +# on heap nio buffers +# +# offheap_buffers +# off heap (direct) nio buffers +# +# offheap_objects +# off heap objects +memtable_allocation_type: heap_buffers + +# Total space to use for commit logs on disk. +# +# If space gets above this value, Cassandra will flush every dirty CF +# in the oldest segment and remove it. So a small total commitlog space +# will tend to cause more flush activity on less-active columnfamilies. +# +# The default value is the smaller of 8192, and 1/4 of the total space +# of the commitlog volume. +# +# commitlog_total_space_in_mb: 8192 + +# This sets the amount of memtable flush writer threads. These will +# be blocked by disk io, and each one will hold a memtable in memory +# while blocked. +# +# memtable_flush_writers defaults to one per data_file_directory. +# +# If your data directories are backed by SSD, you can increase this, but +# avoid having memtable_flush_writers * data_file_directories > number of cores +#memtable_flush_writers: 1 + +# Total space to use for change-data-capture logs on disk. +# +# If space gets above this value, Cassandra will throw WriteTimeoutException +# on Mutations including tables with CDC enabled. A CDCCompactor is responsible +# for parsing the raw CDC logs and deleting them when parsing is completed. +# +# The default value is the min of 4096 mb and 1/8th of the total space +# of the drive where cdc_raw_directory resides. +# cdc_total_space_in_mb: 4096 + +# When we hit our cdc_raw limit and the CDCCompactor is either running behind +# or experiencing backpressure, we check at the following interval to see if any +# new space for cdc-tracked tables has been made available. Default to 250ms +# cdc_free_space_check_interval_ms: 250 + +# A fixed memory pool size in MB for for SSTable index summaries. If left +# empty, this will default to 5% of the heap size. If the memory usage of +# all index summaries exceeds this limit, SSTables with low read rates will +# shrink their index summaries in order to meet this limit. However, this +# is a best-effort process. In extreme conditions Cassandra may need to use +# more than this amount of memory. +index_summary_capacity_in_mb: + +# How frequently index summaries should be resampled. This is done +# periodically to redistribute memory from the fixed-size pool to sstables +# proportional their recent read rates. Setting to -1 will disable this +# process, leaving existing index summaries at their current sampling level. +index_summary_resize_interval_in_minutes: 60 + +# Whether to, when doing sequential writing, fsync() at intervals in +# order to force the operating system to flush the dirty +# buffers. Enable this to avoid sudden dirty buffer flushing from +# impacting read latencies. Almost always a good idea on SSDs; not +# necessarily on platters. +trickle_fsync: false +trickle_fsync_interval_in_kb: 10240 + +# TCP port, for commands and data +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +storage_port: 7000 + +# SSL port, for encrypted communication. Unused unless enabled in +# encryption_options +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +ssl_storage_port: 7001 + +# Address or interface to bind to and tell other Cassandra nodes to connect to. +# You _must_ change this if you want multiple nodes to be able to communicate! +# +# Set listen_address OR listen_interface, not both. +# +# Leaving it blank leaves it up to InetAddress.getLocalHost(). This +# will always do the Right Thing _if_ the node is properly configured +# (hostname, name resolution, etc), and the Right Thing is to use the +# address associated with the hostname (it might not be). +# +# Setting listen_address to 0.0.0.0 is always wrong. +# +listen_address: 172.17.0.3 + +# Set listen_address OR listen_interface, not both. Interfaces must correspond +# to a single address, IP aliasing is not supported. +# listen_interface: eth0 + +# If you choose to specify the interface by name and the interface has an ipv4 and an ipv6 address +# you can specify which should be chosen using listen_interface_prefer_ipv6. If false the first ipv4 +# address will be used. If true the first ipv6 address will be used. Defaults to false preferring +# ipv4. If there is only one address it will be selected regardless of ipv4/ipv6. +# listen_interface_prefer_ipv6: false + +# Address to broadcast to other Cassandra nodes +# Leaving this blank will set it to the same value as listen_address +broadcast_address: 172.17.0.3 + +# When using multiple physical network interfaces, set this +# to true to listen on broadcast_address in addition to +# the listen_address, allowing nodes to communicate in both +# interfaces. +# Ignore this property if the network configuration automatically +# routes between the public and private networks such as EC2. +# listen_on_broadcast_address: false + +# Internode authentication backend, implementing IInternodeAuthenticator; +# used to allow/disallow connections from peer nodes. +# internode_authenticator: org.apache.cassandra.auth.AllowAllInternodeAuthenticator + +# Whether to start the native transport server. +# Please note that the address on which the native transport is bound is the +# same as the rpc_address. The port however is different and specified below. +start_native_transport: true +# port for the CQL native transport to listen for clients on +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +native_transport_port: 9042 +# Enabling native transport encryption in client_encryption_options allows you to either use +# encryption for the standard port or to use a dedicated, additional port along with the unencrypted +# standard native_transport_port. +# Enabling client encryption and keeping native_transport_port_ssl disabled will use encryption +# for native_transport_port. Setting native_transport_port_ssl to a different value +# from native_transport_port will use encryption for native_transport_port_ssl while +# keeping native_transport_port unencrypted. +# native_transport_port_ssl: 9142 +# The maximum threads for handling requests when the native transport is used. +# This is similar to rpc_max_threads though the default differs slightly (and +# there is no native_transport_min_threads, idle threads will always be stopped +# after 30 seconds). +# native_transport_max_threads: 128 +# +# The maximum size of allowed frame. Frame (requests) larger than this will +# be rejected as invalid. The default is 256MB. If you're changing this parameter, +# you may want to adjust max_value_size_in_mb accordingly. +# native_transport_max_frame_size_in_mb: 256 + +# The maximum number of concurrent client connections. +# The default is -1, which means unlimited. +# native_transport_max_concurrent_connections: -1 + +# The maximum number of concurrent client connections per source ip. +# The default is -1, which means unlimited. +# native_transport_max_concurrent_connections_per_ip: -1 + +# Whether to start the thrift rpc server. +start_rpc: false + +# The address or interface to bind the Thrift RPC service and native transport +# server to. +# +# Set rpc_address OR rpc_interface, not both. +# +# Leaving rpc_address blank has the same effect as on listen_address +# (i.e. it will be based on the configured hostname of the node). +# +# Note that unlike listen_address, you can specify 0.0.0.0, but you must also +# set broadcast_rpc_address to a value other than 0.0.0.0. +# +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +rpc_address: 0.0.0.0 + +# Set rpc_address OR rpc_interface, not both. Interfaces must correspond +# to a single address, IP aliasing is not supported. +# rpc_interface: eth1 + +# If you choose to specify the interface by name and the interface has an ipv4 and an ipv6 address +# you can specify which should be chosen using rpc_interface_prefer_ipv6. If false the first ipv4 +# address will be used. If true the first ipv6 address will be used. Defaults to false preferring +# ipv4. If there is only one address it will be selected regardless of ipv4/ipv6. +# rpc_interface_prefer_ipv6: false + +# port for Thrift to listen for clients on +rpc_port: 9160 + +# RPC address to broadcast to drivers and other Cassandra nodes. This cannot +# be set to 0.0.0.0. If left blank, this will be set to the value of +# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must +# be set. +broadcast_rpc_address: 172.17.0.3 + +# enable or disable keepalive on rpc/native connections +rpc_keepalive: true + +# Cassandra provides two out-of-the-box options for the RPC Server: +# +# sync +# One thread per thrift connection. For a very large number of clients, memory +# will be your limiting factor. On a 64 bit JVM, 180KB is the minimum stack size +# per thread, and that will correspond to your use of virtual memory (but physical memory +# may be limited depending on use of stack space). +# +# hsha +# Stands for "half synchronous, half asynchronous." All thrift clients are handled +# asynchronously using a small number of threads that does not vary with the amount +# of thrift clients (and thus scales well to many clients). The rpc requests are still +# synchronous (one thread per active request). If hsha is selected then it is essential +# that rpc_max_threads is changed from the default value of unlimited. +# +# The default is sync because on Windows hsha is about 30% slower. On Linux, +# sync/hsha performance is about the same, with hsha of course using less memory. +# +# Alternatively, can provide your own RPC server by providing the fully-qualified class name +# of an o.a.c.t.TServerFactory that can create an instance of it. +rpc_server_type: sync + +# Uncomment rpc_min|max_thread to set request pool size limits. +# +# Regardless of your choice of RPC server (see above), the number of maximum requests in the +# RPC thread pool dictates how many concurrent requests are possible (but if you are using the sync +# RPC server, it also dictates the number of clients that can be connected at all). +# +# The default is unlimited and thus provides no protection against clients overwhelming the server. You are +# encouraged to set a maximum that makes sense for you in production, but do keep in mind that +# rpc_max_threads represents the maximum number of client requests this server may execute concurrently. +# +# rpc_min_threads: 16 +# rpc_max_threads: 2048 + +# uncomment to set socket buffer sizes on rpc connections +# rpc_send_buff_size_in_bytes: +# rpc_recv_buff_size_in_bytes: + +# Uncomment to set socket buffer size for internode communication +# Note that when setting this, the buffer size is limited by net.core.wmem_max +# and when not setting it it is defined by net.ipv4.tcp_wmem +# See also: +# /proc/sys/net/core/wmem_max +# /proc/sys/net/core/rmem_max +# /proc/sys/net/ipv4/tcp_wmem +# /proc/sys/net/ipv4/tcp_wmem +# and 'man tcp' +# internode_send_buff_size_in_bytes: + +# Uncomment to set socket buffer size for internode communication +# Note that when setting this, the buffer size is limited by net.core.wmem_max +# and when not setting it it is defined by net.ipv4.tcp_wmem +# internode_recv_buff_size_in_bytes: + +# Frame size for thrift (maximum message length). +thrift_framed_transport_size_in_mb: 15 + +# Set to true to have Cassandra create a hard link to each sstable +# flushed or streamed locally in a backups/ subdirectory of the +# keyspace data. Removing these links is the operator's +# responsibility. +incremental_backups: false + +# Whether or not to take a snapshot before each compaction. Be +# careful using this option, since Cassandra won't clean up the +# snapshots for you. Mostly useful if you're paranoid when there +# is a data format change. +snapshot_before_compaction: false + +# Whether or not a snapshot is taken of the data before keyspace truncation +# or dropping of column families. The STRONGLY advised default of true +# should be used to provide data safety. If you set this flag to false, you will +# lose data on truncation or drop. +auto_snapshot: true + +# Granularity of the collation index of rows within a partition. +# Increase if your rows are large, or if you have a very large +# number of rows per partition. The competing goals are these: +# +# - a smaller granularity means more index entries are generated +# and looking up rows withing the partition by collation column +# is faster +# - but, Cassandra will keep the collation index in memory for hot +# rows (as part of the key cache), so a larger granularity means +# you can cache more hot rows +column_index_size_in_kb: 64 + +# Per sstable indexed key cache entries (the collation index in memory +# mentioned above) exceeding this size will not be held on heap. +# This means that only partition information is held on heap and the +# index entries are read from disk. +# +# Note that this size refers to the size of the +# serialized index information and not the size of the partition. +column_index_cache_size_in_kb: 2 + +# Number of simultaneous compactions to allow, NOT including +# validation "compactions" for anti-entropy repair. Simultaneous +# compactions can help preserve read performance in a mixed read/write +# workload, by mitigating the tendency of small sstables to accumulate +# during a single long running compactions. The default is usually +# fine and if you experience problems with compaction running too +# slowly or too fast, you should look at +# compaction_throughput_mb_per_sec first. +# +# concurrent_compactors defaults to the smaller of (number of disks, +# number of cores), with a minimum of 2 and a maximum of 8. +# +# If your data directories are backed by SSD, you should increase this +# to the number of cores. +#concurrent_compactors: 1 + +# Throttles compaction to the given total throughput across the entire +# system. The faster you insert data, the faster you need to compact in +# order to keep the sstable count down, but in general, setting this to +# 16 to 32 times the rate you are inserting data is more than sufficient. +# Setting this to 0 disables throttling. Note that this account for all types +# of compaction, including validation compaction. +compaction_throughput_mb_per_sec: 16 + +# When compacting, the replacement sstable(s) can be opened before they +# are completely written, and used in place of the prior sstables for +# any range that has been written. This helps to smoothly transfer reads +# between the sstables, reducing page cache churn and keeping hot rows hot +sstable_preemptive_open_interval_in_mb: 50 + +# Throttles all outbound streaming file transfers on this node to the +# given total throughput in Mbps. This is necessary because Cassandra does +# mostly sequential IO when streaming data during bootstrap or repair, which +# can lead to saturating the network connection and degrading rpc performance. +# When unset, the default is 200 Mbps or 25 MB/s. +# stream_throughput_outbound_megabits_per_sec: 200 + +# Throttles all streaming file transfer between the datacenters, +# this setting allows users to throttle inter dc stream throughput in addition +# to throttling all network stream traffic as configured with +# stream_throughput_outbound_megabits_per_sec +# When unset, the default is 200 Mbps or 25 MB/s +# inter_dc_stream_throughput_outbound_megabits_per_sec: 200 + +# How long the coordinator should wait for read operations to complete +read_request_timeout_in_ms: 5000 +# How long the coordinator should wait for seq or index scans to complete +range_request_timeout_in_ms: 10000 +# How long the coordinator should wait for writes to complete +write_request_timeout_in_ms: 2000 +# How long the coordinator should wait for counter writes to complete +counter_write_request_timeout_in_ms: 5000 +# How long a coordinator should continue to retry a CAS operation +# that contends with other proposals for the same row +cas_contention_timeout_in_ms: 1000 +# How long the coordinator should wait for truncates to complete +# (This can be much longer, because unless auto_snapshot is disabled +# we need to flush first so we can snapshot before removing the data.) +truncate_request_timeout_in_ms: 60000 +# The default timeout for other, miscellaneous operations +request_timeout_in_ms: 10000 + +# Enable operation timeout information exchange between nodes to accurately +# measure request timeouts. If disabled, replicas will assume that requests +# were forwarded to them instantly by the coordinator, which means that +# under overload conditions we will waste that much extra time processing +# already-timed-out requests. +# +# Warning: before enabling this property make sure to ntp is installed +# and the times are synchronized between the nodes. +cross_node_timeout: false + +# Set socket timeout for streaming operation. +# The stream session is failed if no data/ack is received by any of the participants +# within that period, which means this should also be sufficient to stream a large +# sstable or rebuild table indexes. +# Default value is 86400000ms, which means stale streams timeout after 24 hours. +# A value of zero means stream sockets should never time out. +# streaming_socket_timeout_in_ms: 86400000 + +# phi value that must be reached for a host to be marked down. +# most users should never need to adjust this. +# phi_convict_threshold: 8 + +# endpoint_snitch -- Set this to a class that implements +# IEndpointSnitch. The snitch has two functions: +# +# - it teaches Cassandra enough about your network topology to route +# requests efficiently +# - it allows Cassandra to spread replicas around your cluster to avoid +# correlated failures. It does this by grouping machines into +# "datacenters" and "racks." Cassandra will do its best not to have +# more than one replica on the same "rack" (which may not actually +# be a physical location) +# +# CASSANDRA WILL NOT ALLOW YOU TO SWITCH TO AN INCOMPATIBLE SNITCH +# ONCE DATA IS INSERTED INTO THE CLUSTER. This would cause data loss. +# This means that if you start with the default SimpleSnitch, which +# locates every node on "rack1" in "datacenter1", your only options +# if you need to add another datacenter are GossipingPropertyFileSnitch +# (and the older PFS). From there, if you want to migrate to an +# incompatible snitch like Ec2Snitch you can do it by adding new nodes +# under Ec2Snitch (which will locate them in a new "datacenter") and +# decommissioning the old ones. +# +# Out of the box, Cassandra provides: +# +# SimpleSnitch: +# Treats Strategy order as proximity. This can improve cache +# locality when disabling read repair. Only appropriate for +# single-datacenter deployments. +# +# GossipingPropertyFileSnitch +# This should be your go-to snitch for production use. The rack +# and datacenter for the local node are defined in +# cassandra-rackdc.properties and propagated to other nodes via +# gossip. If cassandra-topology.properties exists, it is used as a +# fallback, allowing migration from the PropertyFileSnitch. +# +# PropertyFileSnitch: +# Proximity is determined by rack and data center, which are +# explicitly configured in cassandra-topology.properties. +# +# Ec2Snitch: +# Appropriate for EC2 deployments in a single Region. Loads Region +# and Availability Zone information from the EC2 API. The Region is +# treated as the datacenter, and the Availability Zone as the rack. +# Only private IPs are used, so this will not work across multiple +# Regions. +# +# Ec2MultiRegionSnitch: +# Uses public IPs as broadcast_address to allow cross-region +# connectivity. (Thus, you should set seed addresses to the public +# IP as well.) You will need to open the storage_port or +# ssl_storage_port on the public IP firewall. (For intra-Region +# traffic, Cassandra will switch to the private IP after +# establishing a connection.) +# +# RackInferringSnitch: +# Proximity is determined by rack and data center, which are +# assumed to correspond to the 3rd and 2nd octet of each node's IP +# address, respectively. Unless this happens to match your +# deployment conventions, this is best used as an example of +# writing a custom Snitch class and is provided in that spirit. +# +# You can use a custom Snitch by setting this to the full class name +# of the snitch, which will be assumed to be on your classpath. +endpoint_snitch: SimpleSnitch + +# controls how often to perform the more expensive part of host score +# calculation +dynamic_snitch_update_interval_in_ms: 100 +# controls how often to reset all host scores, allowing a bad host to +# possibly recover +dynamic_snitch_reset_interval_in_ms: 600000 +# if set greater than zero and read_repair_chance is < 1.0, this will allow +# 'pinning' of replicas to hosts in order to increase cache capacity. +# The badness threshold will control how much worse the pinned host has to be +# before the dynamic snitch will prefer other replicas over it. This is +# expressed as a double which represents a percentage. Thus, a value of +# 0.2 means Cassandra would continue to prefer the static snitch values +# until the pinned host was 20% worse than the fastest. +dynamic_snitch_badness_threshold: 0.1 + +# request_scheduler -- Set this to a class that implements +# RequestScheduler, which will schedule incoming client requests +# according to the specific policy. This is useful for multi-tenancy +# with a single Cassandra cluster. +# NOTE: This is specifically for requests from the client and does +# not affect inter node communication. +# org.apache.cassandra.scheduler.NoScheduler - No scheduling takes place +# org.apache.cassandra.scheduler.RoundRobinScheduler - Round robin of +# client requests to a node with a separate queue for each +# request_scheduler_id. The scheduler is further customized by +# request_scheduler_options as described below. +request_scheduler: org.apache.cassandra.scheduler.NoScheduler + +# Scheduler Options vary based on the type of scheduler +# +# NoScheduler +# Has no options +# +# RoundRobin +# throttle_limit +# The throttle_limit is the number of in-flight +# requests per client. Requests beyond +# that limit are queued up until +# running requests can complete. +# The value of 80 here is twice the number of +# concurrent_reads + concurrent_writes. +# default_weight +# default_weight is optional and allows for +# overriding the default which is 1. +# weights +# Weights are optional and will default to 1 or the +# overridden default_weight. The weight translates into how +# many requests are handled during each turn of the +# RoundRobin, based on the scheduler id. +# +# request_scheduler_options: +# throttle_limit: 80 +# default_weight: 5 +# weights: +# Keyspace1: 1 +# Keyspace2: 5 + +# request_scheduler_id -- An identifier based on which to perform +# the request scheduling. Currently the only valid option is keyspace. +# request_scheduler_id: keyspace + +# Enable or disable inter-node encryption +# JVM defaults for supported SSL socket protocols and cipher suites can +# be replaced using custom encryption options. This is not recommended +# unless you have policies in place that dictate certain settings, or +# need to disable vulnerable ciphers or protocols in case the JVM cannot +# be updated. +# FIPS compliant settings can be configured at JVM level and should not +# involve changing encryption settings here: +# https://docs.oracle.com/javase/8/docs/technotes/guides/security/jsse/FIPS.html +# *NOTE* No custom encryption options are enabled at the moment +# The available internode options are : all, none, dc, rack +# +# If set to dc cassandra will encrypt the traffic between the DCs +# If set to rack cassandra will encrypt the traffic between the racks +# +# The passwords used in these options must match the passwords used when generating +# the keystore and truststore. For instructions on generating these files, see: +# http://download.oracle.com/javase/6/docs/technotes/guides/security/jsse/JSSERefGuide.html#CreateKeystore +# +server_encryption_options: + internode_encryption: none + keystore: conf/.keystore + keystore_password: cassandra + truststore: conf/.truststore + truststore_password: cassandra + # More advanced defaults below: + # protocol: TLS + # algorithm: SunX509 + # store_type: JKS + # cipher_suites: [TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA,TLS_DHE_RSA_WITH_AES_128_CBC_SHA,TLS_DHE_RSA_WITH_AES_256_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA] + # require_client_auth: false + # require_endpoint_verification: false + +# enable or disable client/server encryption. +client_encryption_options: + enabled: false + # If enabled and optional is set to true encrypted and unencrypted connections are handled. + optional: false + keystore: conf/.keystore + keystore_password: cassandra + # require_client_auth: false + # Set trustore and truststore_password if require_client_auth is true + # truststore: conf/.truststore + # truststore_password: cassandra + # More advanced defaults below: + # protocol: TLS + # algorithm: SunX509 + # store_type: JKS + # cipher_suites: [TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA,TLS_DHE_RSA_WITH_AES_128_CBC_SHA,TLS_DHE_RSA_WITH_AES_256_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA] + +# internode_compression controls whether traffic between nodes is +# compressed. +# Can be: +# +# all +# all traffic is compressed +# +# dc +# traffic between different datacenters is compressed +# +# none +# nothing is compressed. +internode_compression: dc + +# Enable or disable tcp_nodelay for inter-dc communication. +# Disabling it will result in larger (but fewer) network packets being sent, +# reducing overhead from the TCP protocol itself, at the cost of increasing +# latency if you block for cross-datacenter responses. +inter_dc_tcp_nodelay: false + +# TTL for different trace types used during logging of the repair process. +tracetype_query_ttl: 86400 +tracetype_repair_ttl: 604800 + +# By default, Cassandra logs GC Pauses greater than 200 ms at INFO level +# This threshold can be adjusted to minimize logging if necessary +# gc_log_threshold_in_ms: 200 + +# If unset, all GC Pauses greater than gc_log_threshold_in_ms will log at +# INFO level +# UDFs (user defined functions) are disabled by default. +# As of Cassandra 3.0 there is a sandbox in place that should prevent execution of evil code. +enable_user_defined_functions: false + +# Enables scripted UDFs (JavaScript UDFs). +# Java UDFs are always enabled, if enable_user_defined_functions is true. +# Enable this option to be able to use UDFs with "language javascript" or any custom JSR-223 provider. +# This option has no effect, if enable_user_defined_functions is false. +enable_scripted_user_defined_functions: false + +# The default Windows kernel timer and scheduling resolution is 15.6ms for power conservation. +# Lowering this value on Windows can provide much tighter latency and better throughput, however +# some virtualized environments may see a negative performance impact from changing this setting +# below their system default. The sysinternals 'clockres' tool can confirm your system's default +# setting. +windows_timer_interval: 1 + + +# Enables encrypting data at-rest (on disk). Different key providers can be plugged in, but the default reads from +# a JCE-style keystore. A single keystore can hold multiple keys, but the one referenced by +# the "key_alias" is the only key that will be used for encrypt opertaions; previously used keys +# can still (and should!) be in the keystore and will be used on decrypt operations +# (to handle the case of key rotation). +# +# It is strongly recommended to download and install Java Cryptography Extension (JCE) +# Unlimited Strength Jurisdiction Policy Files for your version of the JDK. +# (current link: http://www.oracle.com/technetwork/java/javase/downloads/jce8-download-2133166.html) +# +# Currently, only the following file types are supported for transparent data encryption, although +# more are coming in future cassandra releases: commitlog, hints +transparent_data_encryption_options: + enabled: false + chunk_length_kb: 64 + cipher: AES/CBC/PKCS5Padding + key_alias: testing:1 + # CBC IV length for AES needs to be 16 bytes (which is also the default size) + # iv_length: 16 + key_provider: + - class_name: org.apache.cassandra.security.JKSKeyProvider + parameters: + - keystore: conf/.keystore + keystore_password: cassandra + store_type: JCEKS + key_password: cassandra + + +##################### +# SAFETY THRESHOLDS # +##################### + +# When executing a scan, within or across a partition, we need to keep the +# tombstones seen in memory so we can return them to the coordinator, which +# will use them to make sure other replicas also know about the deleted rows. +# With workloads that generate a lot of tombstones, this can cause performance +# problems and even exaust the server heap. +# (http://www.datastax.com/dev/blog/cassandra-anti-patterns-queues-and-queue-like-datasets) +# Adjust the thresholds here if you understand the dangers and want to +# scan more tombstones anyway. These thresholds may also be adjusted at runtime +# using the StorageService mbean. +tombstone_warn_threshold: 1000 +tombstone_failure_threshold: 100000 + +# Log WARN on any batch size exceeding this value. 5kb per batch by default. +# Caution should be taken on increasing the size of this threshold as it can lead to node instability. +batch_size_warn_threshold_in_kb: 5 + +# Fail any batch exceeding this value. 50kb (10x warn threshold) by default. +batch_size_fail_threshold_in_kb: 50 + +# Log WARN on any batches not of type LOGGED than span across more partitions than this limit +unlogged_batch_across_partitions_warn_threshold: 10 + +# Log a warning when compacting partitions larger than this value +compaction_large_partition_warning_threshold_mb: 100 + +# GC Pauses greater than gc_warn_threshold_in_ms will be logged at WARN level +# Adjust the threshold based on your application throughput requirement +# By default, Cassandra logs GC Pauses greater than 200 ms at INFO level +gc_warn_threshold_in_ms: 1000 + +# Maximum size of any value in SSTables. Safety measure to detect SSTable corruption +# early. Any value size larger than this threshold will result into marking an SSTable +# as corrupted. +# max_value_size_in_mb: 256 diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 028c6814f..1babc3cbd 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -31,6 +31,7 @@ type CassandraConnectionProducer struct { Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` Initialized bool + Type string session *gocql.Session sync.Mutex } @@ -46,14 +47,14 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve c.Initialized = true if verifyConnection { - if _, err := c.connection(); err != nil { + if _, err := c.Connection(); err != nil { return fmt.Errorf("error Initalizing Connection: %s", err) } } return nil } -func (c *CassandraConnectionProducer) connection() (interface{}, error) { +func (c *CassandraConnectionProducer) Connection() (interface{}, error) { if !c.Initialized { return nil, errNotInitialized } @@ -106,7 +107,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { var tlsConfig *tls.Config if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + return nil, fmt.Errorf("found certificate for TLS authentication but no private key") } certBundle := &certutil.CertBundle{} From b7e69d0cb62188801201e155e9c5f14b778b5364 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Sun, 23 Apr 2017 00:04:05 -0400 Subject: [PATCH 088/162] Remove commented old method signature --- plugins/database/cassandra/cassandra.go | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 621d6e375..15df0352e 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -60,7 +60,6 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } -// func (c *Cassandra) CreateUser(statements dbplugin.Statements, username, password, expiration string) error { func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { // Grab the lock c.Lock() From 6c8239ba03f4b8311c63dee66f073f97152e9bdd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 10:30:33 -0700 Subject: [PATCH 089/162] Update the builtin keys; move catalog to core; protect against unset plugin directory --- command/server.go | 19 ---------------- helper/builtinplugins/builtin.go | 24 +++++++-------------- helper/pluginutil/runner.go | 4 ++-- vault/logical_system.go | 13 +++++++---- vault/logical_system_test.go | 6 +++--- vault/plugin_catalog.go | 37 +++++++++++++++++++------------- vault/plugin_catalog_test.go | 6 +++--- 7 files changed, 47 insertions(+), 62 deletions(-) diff --git a/command/server.go b/command/server.go index ef9e3e3a0..9697c1dc8 100644 --- a/command/server.go +++ b/command/server.go @@ -8,7 +8,6 @@ import ( "net/url" "os" "os/signal" - "path/filepath" "runtime" "sort" "strconv" @@ -21,7 +20,6 @@ import ( colorable "github.com/mattn/go-colorable" log "github.com/mgutz/logxi/v1" - homedir "github.com/mitchellh/go-homedir" "google.golang.org/grpc/grpclog" @@ -245,23 +243,6 @@ func (c *ServerCommand) Run(args []string) int { coreConfig.DevToken = devRootTokenID } - if config.PluginDirectory == "" { - homePath, err := homedir.Dir() - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error getting user's home directory: %v", err)) - return 1 - } - coreConfig.PluginDirectory = filepath.Join(homePath, "/.vault-plugins/") - err = os.Mkdir(coreConfig.PluginDirectory, 0700) - if err != nil && !os.IsExist(err) { - c.Ui.Output(fmt.Sprintf( - "Error making default plugin directory: %v", err)) - return 1 - } - - } - var disableClustering bool // Initialize the separate HA storage backend, if it exists diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 9c51ae478..b61a51710 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -7,29 +7,21 @@ import ( type BuiltinFactory func() (interface{}, error) -var BuiltinPlugins *builtinPlugins = &builtinPlugins{ - plugins: map[string]BuiltinFactory{ - "mysql-database-plugin": mysql.New, - "postgresql-database-plugin": postgresql.New, - }, +var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ + "mysql-database-plugin": mysql.New, + "postgresql-database-plugin": postgresql.New, } -// The list of builtin plugins should not be changed by any other package, so we -// store them in an unexported variable in this unexported struct. -type builtinPlugins struct { - plugins map[string]BuiltinFactory -} - -func (b *builtinPlugins) Get(name string) (BuiltinFactory, bool) { - f, ok := b.plugins[name] +func Get(name string) (BuiltinFactory, bool) { + f, ok := plugins[name] return f, ok } -func (b *builtinPlugins) Keys() []string { - keys := make([]string, len(b.plugins)) +func Keys() []string { + keys := make([]string, len(plugins)) i := 0 - for k := range b.plugins { + for k := range plugins { keys[i] = k i++ } diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 95de96a5a..9963704e5 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -12,8 +12,8 @@ import ( ) var ( - // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the - // plugin. + // PluginUnwrapTokenEnv is the ENV name used to pass the configuration for + // enabling mlock PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" ) diff --git a/vault/logical_system.go b/vault/logical_system.go index f43de9ef6..cd7113aa3 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -710,13 +710,19 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ - Type: framework.TypeString, + Type: framework.TypeString, + Description: "The name of the plugin", }, "sha_256": &framework.FieldSchema{ Type: framework.TypeString, + Description: `The SHA256 sum of the executable used in the + command field. This should be HEX encoded.`, }, "command": &framework.FieldSchema{ Type: framework.TypeString, + Description: `The command used to start the plugin. The + executable defined in this command must exist in vault's + plugin directory.`, }, }, @@ -767,8 +773,7 @@ func (b *SystemBackend) handlePluginCatalogList(req *logical.Request, d *framewo return nil, err } - resp := logical.ListResponse(plugins) - return resp, nil + return logical.ListResponse(plugins), nil } func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -2524,7 +2529,7 @@ This path responds to the following HTTP methods. `Configures the plugins known to vault`, ` This path responds to the following HTTP methods. - GET / + LIST / Returns a list of names of configured plugins. GET / diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index e9836946c..ea940d540 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1129,8 +1129,8 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("err: %v", err) } - if len(resp.Data["keys"].([]string)) != len(builtinplugins.BuiltinPlugins.Keys()) { - t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.BuiltinPlugins.Keys())) + if len(resp.Data["keys"].([]string)) != len(builtinplugins.Keys()) { + t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.Keys())) } req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/mysql-database-plugin") @@ -1143,7 +1143,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { Name: "mysql-database-plugin", Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") p := resp.Data["plugin"].(*pluginutil.PluginRunner) if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 598a16fac..5d88873b3 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -16,7 +16,8 @@ import ( ) var ( - pluginCatalogPrefix = "plugin-catalog/" + pluginCatalogPath = "core/plugin-catalog/" + ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured") ) // PluginCatalog keeps a record of plugins known to vault. External plugins need @@ -31,7 +32,7 @@ type PluginCatalog struct { func (c *Core) setupPluginCatalog() error { c.pluginCatalog = &PluginCatalog{ - catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), + catalogView: NewBarrierView(c.barrier, pluginCatalogPath), directory: c.pluginDirectory, } @@ -45,22 +46,24 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { c.lock.RLock() defer c.lock.RUnlock() - // Look for external plugins in the barrier - out, err := c.catalogView.Get(name) - if err != nil { - return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) - } - if out != nil { - entry := new(pluginutil.PluginRunner) - if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { - return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + // If the directory isn't set only look for builtin plugins. + if c.directory != "" { + // Look for external plugins in the barrier + out, err := c.catalogView.Get(name) + if err != nil { + return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) } + if out != nil { + entry := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + } - return entry, nil + return entry, nil + } } - // Look for builtin plugins - if factory, ok := builtinplugins.BuiltinPlugins.Get(name); ok { + if factory, ok := builtinplugins.Get(name); ok { return &pluginutil.PluginRunner{ Name: name, Builtin: true, @@ -74,6 +77,10 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { // Set registers a new external plugin with the catalog, or updates an existing // external plugin. It takes the name, command and SHA256 of the plugin. func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { + if c.directory == "" { + return ErrDirectoryNotConfigured + } + c.lock.Lock() defer c.lock.Unlock() @@ -143,7 +150,7 @@ func (c *PluginCatalog) List() ([]string, error) { } // Get the keys for builtin plugins - builtinKeys := builtinplugins.BuiltinPlugins.Keys() + builtinKeys := builtinplugins.Keys() // Use a map to unique the two lists mapKeys := make(map[string]bool) diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index 57e864892..6cfacda7e 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -32,7 +32,7 @@ func TestPluginCatalog_CRUD(t *testing.T) { Name: "mysql-database-plugin", Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { t.Fatal("expected BuiltinFactory did not match actual") @@ -90,7 +90,7 @@ func TestPluginCatalog_CRUD(t *testing.T) { Name: "mysql-database-plugin", Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { t.Fatal("expected BuiltinFactory did not match actual") @@ -113,7 +113,7 @@ func TestPluginCatalog_List(t *testing.T) { core.pluginCatalog.directory = sym // Get builtin plugins and sort them - builtinKeys := builtinplugins.BuiltinPlugins.Keys() + builtinKeys := builtinplugins.Keys() sort.Strings(builtinKeys) // List only builtin plugins From c4e2ad74c57b653bb1d93ac62cfa673644fd6931 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 11:35:32 -0700 Subject: [PATCH 090/162] Update path for the plugin catalog in logical system --- vault/logical_system.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vault/logical_system.go b/vault/logical_system.go index cd7113aa3..843483449 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -63,7 +63,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/reindex", "rotate", "config/auditing/*", - "plugin-catalog/*", + "plugins/catalog/*", }, Unauthenticated: []string{ @@ -694,7 +694,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]), }, &framework.Path{ - Pattern: "plugin-catalog/$", + Pattern: "plugins/catalog/$", Fields: map[string]*framework.FieldSchema{}, @@ -706,7 +706,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, &framework.Path{ - Pattern: "plugin-catalog/(?P.+)", + Pattern: "plugins/catalog/(?P.+)", Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ @@ -2525,7 +2525,7 @@ This path responds to the following HTTP methods. "Lists the headers configured to be audited.", `Returns a list of headers that have been configured to be audited.`, }, - "plugin-catalog": { + "plugins/catalog": { `Configures the plugins known to vault`, ` This path responds to the following HTTP methods. From 657d433330debd76ccdf769f9cab918280abffeb Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 12:15:01 -0700 Subject: [PATCH 091/162] Update the ResponseWrapData function to return a wrapping.ResponseWrapInfo object --- audit/hashstructure.go | 3 ++- audit/hashstructure_test.go | 7 ++++--- helper/pluginutil/mlock.go | 23 +++++++++++++++++++++++ helper/pluginutil/runner.go | 21 ++------------------- helper/pluginutil/tls.go | 7 +++++-- helper/wrapping/wrapinfo.go | 23 +++++++++++++++++++++++ logical/response.go | 26 +++----------------------- logical/system_view.go | 7 ++++--- vault/dynamic_system_view.go | 11 ++++++----- vault/logical_system.go | 3 ++- vault/request_handling.go | 5 +++-- 11 files changed, 77 insertions(+), 59 deletions(-) create mode 100644 helper/pluginutil/mlock.go create mode 100644 helper/wrapping/wrapinfo.go diff --git a/audit/hashstructure.go b/audit/hashstructure.go index 8d0fd7c6c..ea0899ee9 100644 --- a/audit/hashstructure.go +++ b/audit/hashstructure.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/hashicorp/vault/helper/salt" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" "github.com/mitchellh/copystructure" "github.com/mitchellh/reflectwalk" @@ -84,7 +85,7 @@ func Hash(salter *salt.Salt, raw interface{}) error { s.Data = data.(map[string]interface{}) - case *logical.ResponseWrapInfo: + case *wrapping.ResponseWrapInfo: if s == nil { return nil } diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index 5fefa0fa9..6916d0d3a 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/salt" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" "github.com/mitchellh/copystructure" ) @@ -69,7 +70,7 @@ func TestCopy_response(t *testing.T) { Data: map[string]interface{}{ "foo": "bar", }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, Token: "foo", CreationTime: time.Now(), @@ -140,7 +141,7 @@ func TestHash(t *testing.T) { Data: map[string]interface{}{ "foo": "bar", }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, Token: "bar", CreationTime: now, @@ -151,7 +152,7 @@ func TestHash(t *testing.T) { Data: map[string]interface{}{ "foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", CreationTime: now, diff --git a/helper/pluginutil/mlock.go b/helper/pluginutil/mlock.go new file mode 100644 index 000000000..dd9115a89 --- /dev/null +++ b/helper/pluginutil/mlock.go @@ -0,0 +1,23 @@ +package pluginutil + +import ( + "os" + + "github.com/hashicorp/vault/helper/mlock" +) + +var ( + // PluginUnwrapTokenEnv is the ENV name used to pass the configuration for + // enabling mlock + PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" +) + +// OptionallyEnableMlock determines if mlock should be called, and if so enables +// mlock. +func OptionallyEnableMlock() error { + if os.Getenv(PluginMlockEnabled) == "true" { + return mlock.LockMemory() + } + + return nil +} diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 9963704e5..539c3b448 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -3,18 +3,11 @@ package pluginutil import ( "crypto/sha256" "fmt" - "os" "os/exec" "time" plugin "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vault/helper/mlock" -) - -var ( - // PluginUnwrapTokenEnv is the ENV name used to pass the configuration for - // enabling mlock - PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" + "github.com/hashicorp/vault/helper/wrapping" ) // Looker defines the plugin Lookup function that looks into the plugin catalog @@ -27,7 +20,7 @@ type Looker interface { // metadata needed to run a plugin process. This includes looking up Mlock // configuration and wrapping data in a respose wrapped token. type Wrapper interface { - ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockDisabled() bool } @@ -97,13 +90,3 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return client, nil } - -// OptionallyEnableMlock determines if mlock should be called, and if so enables -// mlock. -func OptionallyEnableMlock() error { - if os.Getenv(PluginMlockEnabled) == "true" { - return mlock.LockMemory() - } - - return nil -} diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index d4c0946e4..ee0c54d89 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -103,12 +103,15 @@ func WrapServerConfig(sys Wrapper, certBytes []byte, key *ecdsa.PrivateKey) (str return "", err } - wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ + wrapInfo, err := sys.ResponseWrapData(map[string]interface{}{ "ServerCert": certBytes, "ServerKey": rawKey, }, time.Second*10, true) + if err != nil { + return "", err + } - return wrapToken, err + return wrapInfo.Token, nil } // VaultPluginTLSProvider is run inside a plugin and retrives the response diff --git a/helper/wrapping/wrapinfo.go b/helper/wrapping/wrapinfo.go new file mode 100644 index 000000000..a27219b8a --- /dev/null +++ b/helper/wrapping/wrapinfo.go @@ -0,0 +1,23 @@ +package wrapping + +import "time" + +type ResponseWrapInfo struct { + // Setting to non-zero specifies that the response should be wrapped. + // Specifies the desired TTL of the wrapping token. + TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"` + + // The token containing the wrapped response + Token string `json:"token" structs:"token" mapstructure:"token"` + + // The creation time. This can be used with the TTL to figure out an + // expected expiration. + CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"cration_time"` + + // If the contained response is the output of a token creation call, the + // created token's accessor will be accessible here + WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"` + + // The format to use. This doesn't get returned, it's only internal. + Format string `json:"format" structs:"format" mapstructure:"format"` +} diff --git a/logical/response.go b/logical/response.go index ee6bfe1e2..2a4646a2c 100644 --- a/logical/response.go +++ b/logical/response.go @@ -4,8 +4,8 @@ import ( "errors" "fmt" "reflect" - "time" + "github.com/hashicorp/vault/helper/wrapping" "github.com/mitchellh/copystructure" ) @@ -28,26 +28,6 @@ const ( HTTPStatusCode = "http_status_code" ) -type ResponseWrapInfo struct { - // Setting to non-zero specifies that the response should be wrapped. - // Specifies the desired TTL of the wrapping token. - TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"` - - // The token containing the wrapped response - Token string `json:"token" structs:"token" mapstructure:"token"` - - // The creation time. This can be used with the TTL to figure out an - // expected expiration. - CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"cration_time"` - - // If the contained response is the output of a token creation call, the - // created token's accessor will be accessible here - WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"` - - // The format to use. This doesn't get returned, it's only internal. - Format string `json:"format" structs:"format" mapstructure:"format"` -} - // Response is a struct that stores the response of a request. // It is used to abstract the details of the higher level request protocol. type Response struct { @@ -78,7 +58,7 @@ type Response struct { warnings []string `json:"warnings" structs:"warnings" mapstructure:"warnings"` // Information for wrapping the response in a cubbyhole - WrapInfo *ResponseWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"` + WrapInfo *wrapping.ResponseWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"` } func init() { @@ -123,7 +103,7 @@ func init() { if err != nil { return nil, fmt.Errorf("error copying WrapInfo: %v", err) } - ret.WrapInfo = retWrapInfo.(*ResponseWrapInfo) + ret.WrapInfo = retWrapInfo.(*wrapping.ResponseWrapInfo) } return &ret, nil diff --git a/logical/system_view.go b/logical/system_view.go index b6ab14b1f..e13b63f28 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -6,6 +6,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/wrapping" ) // SystemView exposes system configuration information in a safe way @@ -42,7 +43,7 @@ type SystemView interface { // ResponseWrapData wraps the given data in a cubbyhole and returns the // token used to unwrap. - ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) // LookupPlugin looks into the plugin catalog for a plugin with the given // name. Returns a PluginRunner or an error if a plugin can not be found. @@ -87,8 +88,8 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { return d.ReplicationStateVal } -func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { - return "", errors.New("ResponseWrapData is not implemented in StaticSystemView") +func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { + return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView") } func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index eb99f29c6..9302bfbc1 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" ) @@ -91,14 +92,14 @@ func (d dynamicSystemView) ReplicationState() consts.ReplicationState { // ResponseWrapData wraps the given data in a cubbyhole and returns the // token used to unwrap. -func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { +func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { req := &logical.Request{ Operation: logical.CreateOperation, - Path: "sys/init", + Path: "sys/wrapping/wrap", } resp := &logical.Response{ - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: ttl, }, Data: data, @@ -110,10 +111,10 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim _, err := d.core.wrapInCubbyhole(req, resp) if err != nil { - return "", err + return nil, err } - return resp.WrapInfo.Token, nil + return resp.WrapInfo, nil } // LookupPlugin looks for a plugin with the given name in the plugin catalog. It diff --git a/vault/logical_system.go b/vault/logical_system.go index 843483449..109100090 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/parseutil" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "github.com/mitchellh/mapstructure" @@ -2075,7 +2076,7 @@ func (b *SystemBackend) handleWrappingRewrap( Data: map[string]interface{}{ "response": response, }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: time.Duration(creationTTL), }, }, nil diff --git a/vault/request_handling.go b/vault/request_handling.go index ad37b5aee..1326ef518 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" ) @@ -216,7 +217,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r } if wrapTTL > 0 { - resp.WrapInfo = &logical.ResponseWrapInfo{ + resp.WrapInfo = &wrapping.ResponseWrapInfo{ TTL: wrapTTL, Format: wrapFormat, } @@ -361,7 +362,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log } if wrapTTL > 0 { - resp.WrapInfo = &logical.ResponseWrapInfo{ + resp.WrapInfo = &wrapping.ResponseWrapInfo{ TTL: wrapTTL, Format: wrapFormat, } From ce9688ce8c22284c1b1344a7a3d94cfdee8560a0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 12:21:49 -0700 Subject: [PATCH 092/162] Change MlockDisabled to MlockEnabled --- helper/pluginutil/runner.go | 11 ++++------- logical/system_view.go | 11 ++++++----- vault/core.go | 4 ++-- vault/dynamic_system_view.go | 6 +++--- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 539c3b448..6a8df7385 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -21,7 +21,7 @@ type Looker interface { // configuration and wrapping data in a respose wrapped token. type Wrapper interface { ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) - MlockDisabled() bool + MlockEnabled() bool } // LookWrapper defines the functions for both Looker and Wrapper @@ -63,17 +63,14 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return nil, err } - mlock := "true" - if wrapper.MlockDisabled() { - mlock = "false" - } - cmd := exec.Command(r.Command, r.Args...) cmd.Env = append(cmd.Env, env...) // Add the response wrap token to the ENV of the plugin cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) // Add the mlock setting to the ENV of the plugin - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, mlock)) + if wrapper.MlockEnabled() { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) + } secureConfig := &plugin.SecureConfig{ Checksum: r.Sha256, diff --git a/logical/system_view.go b/logical/system_view.go index e13b63f28..175edc0f9 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -49,8 +49,9 @@ type SystemView interface { // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(string) (*pluginutil.PluginRunner, error) - // MlockDisabled returns the configuration setting for DisableMlock. - MlockDisabled() bool + // MlockEnabled returns the configuration setting for Enableing mlock on + // plugins. + MlockEnabled() bool } type StaticSystemView struct { @@ -60,7 +61,7 @@ type StaticSystemView struct { TaintedVal bool CachingDisabledVal bool Primary bool - DisableMlock bool + EnableMlock bool ReplicationStateVal consts.ReplicationState } @@ -96,6 +97,6 @@ func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, e return nil, errors.New("LookupPlugin is not implemented in StaticSystemView") } -func (d StaticSystemView) MlockDisabled() bool { - return d.DisableMlock +func (d StaticSystemView) MlockEnabled() bool { + return d.EnableMlock } diff --git a/vault/core.go b/vault/core.go index 01ab49f75..260ce096e 100644 --- a/vault/core.go +++ b/vault/core.go @@ -338,7 +338,7 @@ type Core struct { // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog - disableMlock bool + enableMlock bool } // CoreConfig is used to parameterize a core @@ -441,7 +441,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterName: conf.ClusterName, clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), - disableMlock: conf.DisableMlock, + enableMlock: !conf.DisableMlock, } // Wrap the physical backend in a cache layer if enabled and not already wrapped diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 9302bfbc1..edac20140 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -123,7 +123,7 @@ func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, return d.core.pluginCatalog.Get(name) } -// MlockDisabled returns the configuration setting "DisableMlock". -func (d dynamicSystemView) MlockDisabled() bool { - return d.core.disableMlock +// MlockEnabled returns the configuration setting for enabling mlock on plugins. +func (d dynamicSystemView) MlockEnabled() bool { + return d.core.enableMlock } From 5ff317eb8d5d5f71881d60e3d2ed3ea8dbea8537 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 12:47:40 -0700 Subject: [PATCH 093/162] Update root paths test --- vault/logical_system_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index ea940d540..9aae06778 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -32,7 +32,7 @@ func TestSystemBackend_RootPaths(t *testing.T) { "replication/reindex", "rotate", "config/auditing/*", - "plugin-catalog/*", + "plugins/catalog/*", } b := testSystemBackend(t) From 039bc19dd86204f8f03c3ced3076fa80d8c16dbd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 13:48:46 -0700 Subject: [PATCH 094/162] Fix test --- vault/logical_system_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 9aae06778..7bedf7cd6 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1123,7 +1123,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { } c.pluginCatalog.directory = sym - req := logical.TestRequest(t, logical.ListOperation, "plugin-catalog/") + req := logical.TestRequest(t, logical.ListOperation, "plugins/catalog/") resp, err := b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) @@ -1133,7 +1133,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.Keys())) } - req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/mysql-database-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/mysql-database-plugin") resp, err = b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) @@ -1164,7 +1164,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { defer file.Close() command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) - req = logical.TestRequest(t, logical.UpdateOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.UpdateOperation, "plugins/catalog/test-plugin") req.Data["sha_256"] = hex.EncodeToString([]byte{'1'}) req.Data["command"] = command resp, err = b.HandleRequest(req) @@ -1172,7 +1172,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("err: %v", err) } - req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) @@ -1190,13 +1190,13 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { } // Delete plugin - req = logical.TestRequest(t, logical.DeleteOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.DeleteOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) } - req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) if err == nil { t.Fatalf("expected error, plugin not deleted correctly") From 378ae98809077f7d51e02f064f881e9890707ab5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 13:59:12 -0700 Subject: [PATCH 095/162] s/DatabaseType/Database/ --- builtin/logical/database/backend.go | 8 ++++---- builtin/logical/database/dbplugin/client.go | 4 ++-- .../database/dbplugin/databasemiddleware.go | 8 ++++---- builtin/logical/database/dbplugin/plugin.go | 14 +++++++------- builtin/logical/database/dbplugin/server.go | 10 +++++----- builtin/logical/database/path_config_connection.go | 2 +- plugins/database/mssql/mssql.go | 2 +- plugins/helper/database/connutil/connutil.go | 2 +- plugins/helper/database/credsutil/credsutil.go | 2 +- 9 files changed, 26 insertions(+), 26 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index e57fa19c1..7d6ffe9c9 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -41,12 +41,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]dbplugin.DatabaseType) + b.connections = make(map[string]dbplugin.Database) return &b } type databaseBackend struct { - connections map[string]dbplugin.DatabaseType + connections map[string]dbplugin.Database logger log.Logger *framework.Backend @@ -62,13 +62,13 @@ func (b *databaseBackend) closeAllDBs() { db.Close() } - b.connections = nil + b.connections = make(map[string]dbplugin.Database) } // This function is used to retrieve a database object either from the cached // connection map or by using the database config in storage. The caller of this // function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.DatabaseType, error) { +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.Database, error) { // if the object already is built and cached, return it db, ok := b.connections[name] if ok { diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 93db86595..8cfc3aad0 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -29,7 +29,7 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { +func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (Database, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), @@ -65,7 +65,7 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn // ---- RPC client domain ---- -// databasePluginRPCClient implements DatabaseType and is used on the client to +// databasePluginRPCClient implements Database and is used on the client to // make RPC calls to a plugin. type databasePluginRPCClient struct { client *rpc.Client diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index e28a8741e..9ab35b740 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -9,10 +9,10 @@ import ( // ---- Tracing Middleware Domain ---- -// databaseTracingMiddleware wraps a implementation of DatabaseType and executes +// databaseTracingMiddleware wraps a implementation of Database and executes // trace logging on function call. type databaseTracingMiddleware struct { - next DatabaseType + next Database logger log.Logger typeStr string @@ -79,10 +79,10 @@ func (mw *databaseTracingMiddleware) Close() (err error) { // ---- Metrics Middleware Domain ---- -// databaseMetricsMiddleware wraps an implementation of DatabaseTypes and on +// databaseMetricsMiddleware wraps an implementation of Databases and on // function call logs metrics about this instance. type databaseMetricsMiddleware struct { - next DatabaseType + next Database typeStr string } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 9a6691fba..21812423c 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -10,8 +10,8 @@ import ( log "github.com/mgutz/logxi/v1" ) -// DatabaseType is the interface that all database objects must implement. -type DatabaseType interface { +// Database is the interface that all database objects must implement. +type Database interface { Type() (string, error) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) RenewUser(statements Statements, username string, expiration time.Time) error @@ -31,24 +31,24 @@ type Statements struct { // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. -func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { +func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (Database, error) { // Look for plugin in the plugin catalog pluginRunner, err := sys.LookupPlugin(pluginName) if err != nil { return nil, err } - var db DatabaseType + var db Database if pluginRunner.Builtin { // Plugin is builtin so we can retrieve an instance of the interface - // from the pluginRunner. Then cast it to a DatabaseType. + // from the pluginRunner. Then cast it to a Database. dbRaw, err := pluginRunner.BuiltinFactory() if err != nil { return nil, fmt.Errorf("error getting plugin type: %s", err) } var ok bool - db, ok = dbRaw.(DatabaseType) + db, ok = dbRaw.(Database) if !ok { return nil, fmt.Errorf("unsuported database type: %s", pluginName) } @@ -95,7 +95,7 @@ var handshakeConfig = plugin.HandshakeConfig{ // DatabasePlugin implements go-plugin's Plugin interface. It has methods for // retrieving a server and a client instance of the plugin. type DatabasePlugin struct { - impl DatabaseType + impl Database } func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 3a3e23394..04cc3d7e9 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -8,9 +8,9 @@ import ( ) // NewPluginServer is called from within a plugin and wraps the provided -// DatabaseType implementation in a databasePluginRPCServer object and starts a +// Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func NewPluginServer(db DatabaseType) { +func NewPluginServer(db Database) { dbPlugin := &DatabasePlugin{ impl: db, } @@ -35,10 +35,10 @@ func NewPluginServer(db DatabaseType) { // ---- RPC server domain ---- -// databasePluginRPCServer implements an RPC version of DatabaseType and is run -// inside a plugin. It wraps an underlying implementation of DatabaseType. +// databasePluginRPCServer implements an RPC version of Database and is run +// inside a plugin. It wraps an underlying implementation of Database. type databasePluginRPCServer struct { - impl DatabaseType + impl Database } func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 2a0022b4d..f154ae164 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -16,7 +16,7 @@ var ( respErrEmptyName = logical.ErrorResponse("Empty name attribute given") ) -// DatabaseConfig is used by the Factory function to configure a DatabaseType +// DatabaseConfig is used by the Factory function to configure a Database // object. type DatabaseConfig struct { PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index b0e0ab6d4..54f2a9711 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -15,7 +15,7 @@ import ( const msSQLTypeName = "mssql" -// MSSQL is an implementation of DatabaseType interface +// MSSQL is an implementation of Database interface type MSSQL struct { connutil.ConnectionProducer credsutil.CredentialsProducer diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go index 6de3299e3..c43691c61 100644 --- a/plugins/helper/database/connutil/connutil.go +++ b/plugins/helper/database/connutil/connutil.go @@ -9,7 +9,7 @@ var ( errNotInitialized = errors.New("connection has not been initalized") ) -// ConnectionProducer can be used as an embeded interface in the DatabaseType +// ConnectionProducer can be used as an embeded interface in the Database // definition. It implements the methods dealing with individual database // connections and is used in all the builtin database types. type ConnectionProducer interface { diff --git a/plugins/helper/database/credsutil/credsutil.go b/plugins/helper/database/credsutil/credsutil.go index 7f388a0f7..bc35617ac 100644 --- a/plugins/helper/database/credsutil/credsutil.go +++ b/plugins/helper/database/credsutil/credsutil.go @@ -2,7 +2,7 @@ package credsutil import "time" -// CredentialsProducer can be used as an embeded interface in the DatabaseType +// CredentialsProducer can be used as an embeded interface in the Database // definition. It implements the methods for generating user information for a // particular database type and is used in all the builtin database types. type CredentialsProducer interface { From f25b3677325888ce4d17dbe496a9bfb3ce772aee Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 14:03:48 -0700 Subject: [PATCH 096/162] Don't uppercase ErrorResponses --- builtin/logical/database/path_config_connection.go | 8 ++++---- builtin/logical/database/path_role_create.go | 2 +- builtin/logical/database/path_roles.go | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f154ae164..965364dc5 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -12,8 +12,8 @@ import ( ) var ( - respErrEmptyPluginName = logical.ErrorResponse("Empty plugin name") - respErrEmptyName = logical.ErrorResponse("Empty name attribute given") + respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") + respErrEmptyName = logical.ErrorResponse("empty name attribute given") ) // DatabaseConfig is used by the Factory function to configure a Database @@ -199,13 +199,13 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } err = db.Initialize(config.ConnectionDetails, verifyConnection) if err != nil { db.Close() - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } // Grab the mutex lock diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 631802dff..a8da211f2 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -38,7 +38,7 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { return nil, err } if role == nil { - return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil + return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil } dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index b3393b1ba..e85b123dc 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -136,12 +136,12 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty role name attribute given"), nil + return logical.ErrorResponse("empty role name attribute given"), nil } dbName := data.Get("db_name").(string) if dbName == "" { - return logical.ErrorResponse("Empty database name attribute given"), nil + return logical.ErrorResponse("empty database name attribute given"), nil } // Get statements @@ -157,12 +157,12 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { defaultTTL, err := time.ParseDuration(defaultTTLRaw) if err != nil { return logical.ErrorResponse(fmt.Sprintf( - "Invalid default_ttl: %s", err)), nil + "invalid default_ttl: %s", err)), nil } maxTTL, err := time.ParseDuration(maxTTLRaw) if err != nil { return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_ttl: %s", err)), nil + "invalid max_ttl: %s", err)), nil } statements := dbplugin.Statements{ From cb1f1d418c618abef5eac067434bf98bb7663516 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 16:20:20 -0700 Subject: [PATCH 097/162] Only run Abs on the plugin directory if it's set --- vault/core.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vault/core.go b/vault/core.go index 260ce096e..f3ed06d38 100644 --- a/vault/core.go +++ b/vault/core.go @@ -466,9 +466,11 @@ func NewCore(conf *CoreConfig) (*Core, error) { } var err error - c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) - if err != nil { - return nil, fmt.Errorf("core setup failed: %v", err) + if conf.PluginDirectory != "" { + c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) + if err != nil { + return nil, fmt.Errorf("core setup failed, could not verify plugin directory: %v", err) + } } // Construct a new AES-GCM barrier From e4e61ec18ca4b4001ed01dd2a4d809eaff6ed864 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 18:31:27 -0700 Subject: [PATCH 098/162] return a 404 when no plugin is found --- vault/dynamic_system_view.go | 11 ++++++++++- vault/logical_system.go | 3 +++ vault/plugin_catalog.go | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index edac20140..3844b46bf 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -1,6 +1,7 @@ package vault import ( + "fmt" "time" "github.com/hashicorp/vault/helper/consts" @@ -120,7 +121,15 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim // LookupPlugin looks for a plugin with the given name in the plugin catalog. It // returns a PluginRunner or an error if no plugin was found. func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { - return d.core.pluginCatalog.Get(name) + r, err := d.core.pluginCatalog.Get(name) + if err != nil { + return nil, err + } + if r == nil { + return nil, fmt.Errorf("no plugin found with name: %s", name) + } + + return r, nil } // MlockEnabled returns the configuration setting for enabling mlock on plugins. diff --git a/vault/logical_system.go b/vault/logical_system.go index 109100090..4dd66f814 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -815,6 +815,9 @@ func (b *SystemBackend) handlePluginCatalogRead(req *logical.Request, d *framewo if err != nil { return nil, err } + if plugin == nil { + return nil, nil + } return &logical.Response{ Data: map[string]interface{}{ diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 5d88873b3..095d81b1e 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -71,7 +71,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { }, nil } - return nil, fmt.Errorf("no plugin found with name: %s", name) + return nil, nil } // Set registers a new external plugin with the catalog, or updates an existing From b52b410a477d6eb95035e1b7b687cb4776df9f32 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 21:24:19 -0700 Subject: [PATCH 099/162] Update test to reflect the correct read response --- vault/logical_system_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 7bedf7cd6..aa2ce449a 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1198,7 +1198,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) - if err == nil { - t.Fatalf("expected error, plugin not deleted correctly") + if resp != nil || err != nil { + t.Fatalf("expected nil response, plugin not deleted correctly got resp: %v, err: %v", resp, err) } } From bed1c17b1ed57d84ed4c7f5865833b484b090ee5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 10:24:19 -0700 Subject: [PATCH 100/162] Update logging to new structure --- .../database/dbplugin/databasemiddleware.go | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 9ab35b740..13591e516 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -25,10 +25,10 @@ func (mw *databaseTracingMiddleware) Type() (string, error) { func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) } return mw.next.CreateUser(statements, usernamePrefix, expiration) } @@ -36,10 +36,10 @@ func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameP func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/RenewUser: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) } return mw.next.RenewUser(statements, username, expiration) } @@ -47,10 +47,10 @@ func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username s func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/RevokeUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/RevokeUser: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) } return mw.next.RevokeUser(statements, username) } @@ -58,10 +58,10 @@ func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) } return mw.next.Initialize(conf, verifyConnection) } @@ -69,10 +69,10 @@ func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, ver func (mw *databaseTracingMiddleware) Close() (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/Close: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/Close: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) } return mw.next.Close() } From 3d3e4eb5a4eb9a7440c20460ac40f785fe13fb80 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 10:26:23 -0700 Subject: [PATCH 101/162] Use TypeCommaStringSlice for allowed_roles --- builtin/logical/database/path_config_connection.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 965364dc5..557d3f3cb 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -98,10 +98,10 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { }, "allowed_roles": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Comma separated list of the role names allowed to - get creds from this database connection. If not set all roles - are allowed.`, + Type: framework.TypeCommaStringSlice, + Description: `Comma separated string or array of the role names + allowed to get creds from this database connection. If not set + all roles are allowed.`, }, }, From eb0f831d6a6f81d4c21f84419144e60d876bd531 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 10:39:17 -0700 Subject: [PATCH 102/162] Rename path_role_create to path_creds_create --- builtin/logical/database/backend.go | 2 +- .../{path_role_create.go => path_creds_create.go} | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) rename builtin/logical/database/{path_role_create.go => path_creds_create.go} (84%) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 7d6ffe9c9..e8cf98ebb 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -27,7 +27,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), - pathRoleCreate(&b), + pathCredsCreate(&b), pathResetConnection(&b), }, diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_creds_create.go similarity index 84% rename from builtin/logical/database/path_role_create.go rename to builtin/logical/database/path_creds_create.go index a8da211f2..341c61d67 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathRoleCreate(b *databaseBackend) *framework.Path { +func pathCredsCreate(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: "creds/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ @@ -20,15 +20,15 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathRoleCreateRead(), + logical.ReadOperation: b.pathCredsCreateRead(), }, - HelpSynopsis: pathRoleCreateReadHelpSyn, - HelpDescription: pathRoleCreateReadHelpDesc, + HelpSynopsis: pathCredsCreateReadHelpSyn, + HelpDescription: pathCredsCreateReadHelpDesc, } } -func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { +func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) @@ -58,7 +58,6 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { // Get the Database object db, err := b.getOrCreateDBObj(req.Storage, role.DBName) if err != nil { - // TODO: return a resp error instead? return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } @@ -82,11 +81,11 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { } } -const pathRoleCreateReadHelpSyn = ` +const pathCredsCreateReadHelpSyn = ` Request database credentials for a certain role. ` -const pathRoleCreateReadHelpDesc = ` +const pathCredsCreateReadHelpDesc = ` This path reads database credentials for a certain role. The database credentials will be generated on demand and will be automatically revoked when the lease is up. From 207d01fd39b5efeeeb9de39882bfa989e41f49f9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 11:11:10 -0700 Subject: [PATCH 103/162] Update the connection details data and fix allowedRoles --- builtin/logical/database/path_config_connection.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 557d3f3cb..7c175848f 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -6,7 +6,6 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -187,9 +186,14 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { verifyConnection := data.Get("verify_connection").(bool) - // Pasrse and dedupe allowed roles from a comma separated string. - allowedRolesRaw := data.Get("allowed_roles").(string) - allowedRoles := strutil.ParseDedupAndSortStrings(allowedRolesRaw, ",") + allowedRoles := data.Get("allowed_roles").([]string) + + // Remove these entries from the data before we store it keyed under + // ConnectionDetails. + delete(data.Raw, "name") + delete(data.Raw, "plugin_name") + delete(data.Raw, "allowed_roles") + delete(data.Raw, "verify_connection") config := &DatabaseConfig{ ConnectionDetails: data.Raw, From e3e5f12f9e293d455304b0b2cadc098c36d92cc5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 11:48:24 -0700 Subject: [PATCH 104/162] Default deny when allowed roles is empty --- builtin/logical/database/backend_test.go | 84 +++++++++++++++++-- .../database/path_config_connection.go | 4 +- builtin/logical/database/path_creds_create.go | 2 +- 3 files changed, 80 insertions(+), 10 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 2ece767fc..08317cbdc 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -113,6 +113,7 @@ func TestBackend_config_connection(t *testing.T) { "connection_url": "sample_connection_url", "plugin_name": "postgresql-database-plugin", "verify_connection": false, + "allowed_roles": []string{"*"}, } configReq := &logical.Request{ @@ -127,9 +128,11 @@ func TestBackend_config_connection(t *testing.T) { } expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": configData, - "allowed_roles": []string{}, + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_connection_url", + }, + "allowed_roles": []string{"*"}, } configReq.Operation = logical.ReadOperation resp, err = b.HandleRequest(configReq) @@ -164,6 +167,7 @@ func TestBackend_basic(t *testing.T) { data := map[string]interface{}{ "connection_url": connURL, "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, } req := &logical.Request{ Operation: logical.UpdateOperation, @@ -290,6 +294,7 @@ func TestBackend_connectionCrud(t *testing.T) { data = map[string]interface{}{ "connection_url": connURL, "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, } req = &logical.Request{ Operation: logical.UpdateOperation, @@ -304,9 +309,11 @@ func TestBackend_connectionCrud(t *testing.T) { // Read connection expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": data, - "allowed_roles": []string{}, + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": connURL, + }, + "allowed_roles": []string{"plugin-role-test"}, } req.Operation = logical.ReadOperation resp, err = b.HandleRequest(req) @@ -506,7 +513,6 @@ func TestBackend_allowedRoles(t *testing.T) { data := map[string]interface{}{ "connection_url": connURL, "plugin_name": "postgresql-database-plugin", - "allowed_roles": "allow, allowed", } req := &logical.Request{ Operation: logical.UpdateOperation, @@ -567,6 +573,70 @@ func TestBackend_allowedRoles(t *testing.T) { t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) } + // update connection with * allowed roles connection + data = map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "*", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds, should work. + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/allowed", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } + + // update connection with allowed roles + data = map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "allow, allowed", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds from denied role, should fail + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/denied", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(req) + if err != logical.ErrPermissionDenied { + t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) + } + // Get creds from allowed role, should work. data = map[string]interface{}{} req = &logical.Request{ diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 7c175848f..f52cfec59 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -99,8 +99,8 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { "allowed_roles": &framework.FieldSchema{ Type: framework.TypeCommaStringSlice, Description: `Comma separated string or array of the role names - allowed to get creds from this database connection. If not set - all roles are allowed.`, + allowed to get creds from this database connection. If empty no + roles are allowed. If "*" all roles are allowed.`, }, }, diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 341c61d67..9bbaceb54 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -48,7 +48,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // If role name isn't in the database's allowed roles, send back a // permission denied. - if len(dbConfig.AllowedRoles) > 0 && !strutil.StrListContains(dbConfig.AllowedRoles, name) { + if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContains(dbConfig.AllowedRoles, name) { return nil, logical.ErrPermissionDenied } From 892812d67d149dad08090f37e03ddeabf62d6735 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:02:37 -0700 Subject: [PATCH 105/162] Change ttl types to TypeDurationSecond --- builtin/logical/database/path_roles.go | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index e85b123dc..c81261804 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -1,7 +1,6 @@ package database import ( - "fmt" "time" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" @@ -65,12 +64,12 @@ func pathRoles(b *databaseBackend) *framework.Path { }, "default_ttl": { - Type: framework.TypeString, + Type: framework.TypeDurationSecond, Description: "Default ttl for role.", }, "max_ttl": { - Type: framework.TypeString, + Type: framework.TypeDurationSecond, Description: "Maximum time a credential is valid for", }, }, @@ -114,8 +113,8 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc { "revocation_statements": role.Statements.RevocationStatements, "rollback_statements": role.Statements.RollbackStatements, "renew_statements": role.Statements.RenewStatements, - "default_ttl": role.DefaultTTL.String(), - "max_ttl": role.MaxTTL.String(), + "default_ttl": role.DefaultTTL.Seconds(), + "max_ttl": role.MaxTTL.Seconds(), }, }, nil } @@ -151,19 +150,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { renewStmts := data.Get("renew_statements").(string) // Get TTLs - defaultTTLRaw := data.Get("default_ttl").(string) - maxTTLRaw := data.Get("max_ttl").(string) - - defaultTTL, err := time.ParseDuration(defaultTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "invalid default_ttl: %s", err)), nil - } - maxTTL, err := time.ParseDuration(maxTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "invalid max_ttl: %s", err)), nil - } + defaultTTLRaw := data.Get("default_ttl").(int) + maxTTLRaw := data.Get("max_ttl").(int) + defaultTTL := time.Duration(defaultTTLRaw) * time.Second + maxTTL := time.Duration(maxTTLRaw) * time.Second statements := dbplugin.Statements{ CreationStatements: creationStmts, From 4782d9d2aff9f8969476aac9dc33e9948c953541 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:29:16 -0700 Subject: [PATCH 106/162] Update the error messages for renew and revoke --- builtin/logical/database/secret_creds.go | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index ffc59cf3f..2704eb287 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -38,7 +38,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { return nil, err } if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("error during renew: could not find role with name %s", req.Secret.InternalData["role"]) } f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) @@ -54,7 +54,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Get our connection db, err := b.getOrCreateDBObj(req.Storage, role.DBName) if err != nil { - return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) + return nil, fmt.Errorf("error during renew: %s", err) } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -90,25 +90,9 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { return nil, err } if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"]) } - /* TODO: think about how to handle this case. - if !ok { - role, err := b.Role(req.Storage, roleNameRaw.(string)) - if err != nil { - return nil, err - } - if role == nil { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) - } else { - revocationSQL = role.RevocationStatement - } - }*/ - // Grab the read lock b.Lock() defer b.Unlock() @@ -116,7 +100,7 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Get our connection db, err := b.getOrCreateDBObj(req.Storage, role.DBName) if err != nil { - return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) + return nil, fmt.Errorf("error during revoke: %s", err) } err = db.RevokeUser(role.Statements, username) From 6a1ae9160dd712acd795998ba29eb3f5c9346527 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:34:45 -0700 Subject: [PATCH 107/162] Add mssql builtin plugin type --- helper/builtinplugins/builtin.go | 2 ++ plugins/database/mssql/mssql.go | 11 +++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index b61a51710..c20a92603 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,6 +1,7 @@ package builtinplugins import ( + "github.com/hashicorp/vault/plugins/database/mssql" "github.com/hashicorp/vault/plugins/database/mysql" "github.com/hashicorp/vault/plugins/database/postgresql" ) @@ -10,6 +11,7 @@ type BuiltinFactory func() (interface{}, error) var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ "mysql-database-plugin": mysql.New, "postgresql-database-plugin": postgresql.New, + "mssql-database-plugin": mssql.New, } func Get(name string) (BuiltinFactory, bool) { diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 54f2a9711..48da8ff08 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -21,7 +21,7 @@ type MSSQL struct { credsutil.CredentialsProducer } -func New() *MSSQL { +func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = msSQLTypeName @@ -35,14 +35,17 @@ func New() *MSSQL { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instantiates a MSSQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*MSSQL)) return nil } From 6252f48dfed08140b1a7c9efcafd37c1380821cc Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:52:10 -0700 Subject: [PATCH 108/162] Fix MSSQL test --- plugins/database/mssql/mssql_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 512033bd7..0dc18cb3e 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -27,7 +27,8 @@ func TestMSSQL_Initialize(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MSSQL) err := db.Initialize(connectionDetails, true) if err != nil { @@ -55,7 +56,8 @@ func TestMSSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MSSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -91,7 +93,8 @@ func TestMSSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MSSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) From d0cad5345ab528fd510ee4226ba1ae82b3d36233 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 15:23:14 -0700 Subject: [PATCH 109/162] Update to a RWMutex --- builtin/logical/database/backend.go | 20 +++++---- .../database/path_config_connection.go | 2 +- builtin/logical/database/path_creds_create.go | 20 ++++++--- builtin/logical/database/secret_creds.go | 42 ++++++++++++++----- 4 files changed, 58 insertions(+), 26 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index e8cf98ebb..2aff47375 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -50,7 +50,7 @@ type databaseBackend struct { logger log.Logger *framework.Backend - sync.Mutex + sync.RWMutex } // resetAllDBs closes all connections from all database types @@ -66,21 +66,23 @@ func (b *databaseBackend) closeAllDBs() { } // This function is used to retrieve a database object either from the cached -// connection map or by using the database config in storage. The caller of this -// function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.Database, error) { - // if the object already is built and cached, return it +// connection map. The caller of this function needs to hold the backend's read +// lock. +func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { db, ok := b.connections[name] - if ok { - return db, nil - } + return db, ok +} +// This function creates a new db object from the stored configuration and +// caches it in the connections map. The caller of this function needs to hold +// the backend's write lock +func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { config, err := b.DatabaseConfig(s, name) if err != nil { return nil, err } - db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) + db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f52cfec59..39eb3d000 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -62,7 +62,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { b.clearConnection(name) // Execute plugin again, we don't need the object so throw away. - _, err := b.getOrCreateDBObj(req.Storage, name) + _, err := b.createDBObj(req.Storage, name) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 9bbaceb54..60f0c5e3e 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -52,13 +52,23 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return nil, logical.ErrPermissionDenied } - b.Lock() - defer b.Unlock() + b.RLock() // Get the Database object - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } expiration := time.Now().Add(role.DefaultTTL) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 2704eb287..690b41565 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -48,13 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { } // Grab the read lock - b.Lock() - defer b.Unlock() + b.RLock() - // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("error during renew: %s", err) + // Get the Database object + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -94,13 +104,23 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { } // Grab the read lock - b.Lock() - defer b.Unlock() + b.RLock() // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("error during revoke: %s", err) + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } err = db.RevokeUser(role.Statements, username) From 081101c7cf97853109a97c57ee53846af9fd0679 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 15:55:34 -0700 Subject: [PATCH 110/162] Add an error check to reset a plugin if it is closed --- builtin/logical/database/backend.go | 10 ++++++++++ .../logical/database/path_config_connection.go | 12 ++---------- builtin/logical/database/path_creds_create.go | 10 +++++++--- builtin/logical/database/secret_creds.go | 18 ++++++++++++------ 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 2aff47375..4bac4b0c3 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -2,6 +2,7 @@ package database import ( "fmt" + "net/rpc" "strings" "sync" @@ -152,6 +153,15 @@ func (b *databaseBackend) clearConnection(name string) { } } +func (b *databaseBackend) closeIfShutdown(name string, err error) { + // Plugin has shutdown, close it so next call can reconnect. + if err == rpc.ErrShutdown { + b.Lock() + b.clearConnection(name) + b.Unlock() + } +} + const backendHelp = ` The database backend supports using many different databases as secret backends, including but not limited to: diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 39eb3d000..4c0863fd7 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -216,16 +216,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.Lock() defer b.Unlock() - if _, ok := b.connections[name]; ok { - // Close and remove the old connection - err := b.connections[name].Close() - if err != nil { - db.Close() - return nil, err - } - - delete(b.connections, name) - } + // Close and remove the old connection + b.clearConnection(name) // Save the new connection b.connections[name] = db diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 60f0c5e3e..7bc7dfa6f 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -52,7 +52,9 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return nil, logical.ErrPermissionDenied } + // Grab the read lock b.RLock() + var unlockFunc func() = b.RUnlock // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -60,22 +62,24 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Upgrade lock b.RUnlock() b.Lock() - defer b.Unlock() + unlockFunc = b.Unlock // Create a new DB object db, err = b.createDBObj(req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - } else { - defer b.RUnlock() } expiration := time.Now().Add(role.DefaultTTL) // Create the user username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) + // Unlock + unlockFunc() if err != nil { + b.closeIfShutdown(role.DBName, err) return nil, err } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 690b41565..c3dfcb973 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -49,6 +49,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Grab the read lock b.RLock() + var unlockFunc func() = b.RUnlock // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -56,21 +57,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Upgrade lock b.RUnlock() b.Lock() - defer b.Unlock() + unlockFunc = b.Unlock // Create a new DB object db, err = b.createDBObj(req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - } else { - defer b.RUnlock() } // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { err := db.RenewUser(role.Statements, username, expireTime) + // Unlock + unlockFunc() if err != nil { + b.closeIfShutdown(role.DBName, err) return nil, err } } @@ -105,6 +108,7 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Grab the read lock b.RLock() + var unlockFunc func() = b.RUnlock // Get our connection db, ok := b.getDBObj(role.DBName) @@ -112,19 +116,21 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Upgrade lock b.RUnlock() b.Lock() - defer b.Unlock() + unlockFunc = b.Unlock // Create a new DB object db, err = b.createDBObj(req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - } else { - defer b.RUnlock() } err = db.RevokeUser(role.Statements, username) + // Unlock + unlockFunc() if err != nil { + b.closeIfShutdown(role.DBName, err) return nil, err } From 53752c30022860fe1b022e687921cf7a748f8276 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 16:43:42 -0700 Subject: [PATCH 111/162] Add check to ensure we don't overwrite existing connections --- builtin/logical/database/backend.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4bac4b0c3..da8c8384a 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -78,12 +78,17 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { // caches it in the connections map. The caller of this function needs to hold // the backend's write lock func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { + db, ok := b.connections[name] + if ok { + return db, nil + } + config, err := b.DatabaseConfig(s, name) if err != nil { return nil, err } - db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) + db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return nil, err } From fadf6c439fe4a693b51d74ff6d1048184d1dbfc4 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Thu, 27 Apr 2017 11:07:52 -0400 Subject: [PATCH 112/162] Update New() func signature and its references --- plugins/database/cassandra/cassandra.go | 11 +++++++---- plugins/database/cassandra/cassandra_test.go | 12 ++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 15df0352e..24d87a353 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -24,7 +24,7 @@ type Cassandra struct { credsutil.CredentialsProducer } -func New() *Cassandra { +func New() (interface{}, error) { connProducer := &connutil.CassandraConnectionProducer{} connProducer.Type = cassandraTypeName @@ -35,14 +35,17 @@ func New() *Cassandra { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*Cassandra)) return nil } diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go index b81c32710..9e98ec48f 100644 --- a/plugins/database/cassandra/cassandra_test.go +++ b/plugins/database/cassandra/cassandra_test.go @@ -80,7 +80,8 @@ func TestCassandra_Initialize(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer) err := db.Initialize(connectionDetails, true) @@ -109,7 +110,8 @@ func TestCassandra_CreateUser(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -140,7 +142,8 @@ func TestMyCassandra_RenewUser(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -176,7 +179,8 @@ func TestCassandra_RevokeUser(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) From c11f2638b92e62a226598b7d40a1f35dd84c242e Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 27 Apr 2017 22:56:06 -0700 Subject: [PATCH 113/162] If user provides a revocation statement for MSSQL plugin honor it --- plugins/database/mssql/mssql.go | 45 +++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 48da8ff08..a0d863080 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -142,6 +142,51 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir // then kill pending connections from that user, and finally drop the user and login from the // database instance. func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { + if statements.RevocationStatements == "" { + return m.revokeUserDefault(username) + } + + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (m *MSSQL) revokeUserDefault(username string) error { // Get connection db, err := m.getConnection() if err != nil { From 9a07675d861e2c6d40286239ceb37dd6b8beb94d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 27 Apr 2017 22:59:22 -0700 Subject: [PATCH 114/162] Update username length for MSSQL --- plugins/database/mssql/mssql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index a0d863080..b608428e5 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -26,8 +26,8 @@ func New() (interface{}, error) { connProducer.Type = msSQLTypeName credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 4, - UsernameLen: 16, + DisplayNameLen: 20, + UsernameLen: 128, } dbType := &MSSQL{ From 43cf6198714e0fcfacacc5a4a90ac7b75e2670a2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 27 Apr 2017 23:02:33 -0700 Subject: [PATCH 115/162] Update the username length for postgresql --- plugins/database/postgresql/postgresql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 5781b6c3d..e90e0f8cb 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -22,8 +22,8 @@ func New() (interface{}, error) { connProducer.Type = postgreSQLTypeName credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 4, - UsernameLen: 16, + DisplayNameLen: 10, + UsernameLen: 63, } dbType := &PostgreSQL{ From 5076701bea42724d05e569d8f89f0b6c3dc65ed7 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 1 May 2017 11:27:35 -0400 Subject: [PATCH 116/162] Honor statements for RevokeUser on Cassandra backend, add method comments --- plugins/database/cassandra/cassandra.go | 44 ++++++++++++++----- .../cassandra/test-fixtures/cassandra.yaml | 8 ++-- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 24d87a353..bf1cbab92 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -1,11 +1,11 @@ package cassandra import ( - "fmt" "strings" "time" "github.com/gocql/gocql" + multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins/helper/database/connutil" @@ -14,16 +14,18 @@ import ( ) const ( - defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` - defaultRollbackCQL = `DROP USER '{{username}}';` - cassandraTypeName = "cassandra" + defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultUserDeletionCQL = `DROP USER '{{username}}';` + cassandraTypeName = "cassandra" ) +// Cassandra is an implementation of Database interface type Cassandra struct { connutil.ConnectionProducer credsutil.CredentialsProducer } +// New returns a new Cassandra instance func New() (interface{}, error) { connProducer := &connutil.CassandraConnectionProducer{} connProducer.Type = cassandraTypeName @@ -38,7 +40,7 @@ func New() (interface{}, error) { return dbType, nil } -// Run instantiates a MySQL object, and runs the RPC server for the plugin +// Run instantiates a Cassandra object, and runs the RPC server for the plugin func Run() error { dbType, err := New() if err != nil { @@ -50,6 +52,7 @@ func Run() error { return nil } +// Type returns the TypeName for this backend func (c *Cassandra) Type() (string, error) { return cassandraTypeName, nil } @@ -63,6 +66,8 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } +// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by +// the CreationStatement provided. func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { // Grab the lock c.Lock() @@ -76,11 +81,11 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix st creationCQL := statements.CreationStatements if creationCQL == "" { - creationCQL = defaultCreationCQL + creationCQL = defaultUserCreationCQL } rollbackCQL := statements.RollbackStatements if rollbackCQL == "" { - rollbackCQL = defaultRollbackCQL + rollbackCQL = defaultUserDeletionCQL } username, err = c.GenerateUsername(usernamePrefix) @@ -113,7 +118,6 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix st session.Query(dbutil.QueryHelper(query, map[string]string{ "username": username, - "password": password, })).Exec() } return "", "", err @@ -123,11 +127,13 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix st return username, password, nil } +// RenewUser is not supported on Cassandra, so this is a no-op. func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } +// RevokeUser attempts to drop the specified user. func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { // Grab the lock c.Lock() @@ -138,10 +144,24 @@ func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) return err } - err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() - if err != nil { - return fmt.Errorf("error removing user '%s': %s", username, err) + revocationCQL := statements.RevocationStatements + if revocationCQL == "" { + revocationCQL = defaultUserDeletionCQL } - return nil + var result *multierror.Error + for _, query := range strutil.ParseArbitraryStringSlice(revocationCQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err := session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + })).Exec() + + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() } diff --git a/plugins/database/cassandra/test-fixtures/cassandra.yaml b/plugins/database/cassandra/test-fixtures/cassandra.yaml index 5b12c8cf4..54f47d34a 100644 --- a/plugins/database/cassandra/test-fixtures/cassandra.yaml +++ b/plugins/database/cassandra/test-fixtures/cassandra.yaml @@ -421,7 +421,7 @@ seed_provider: parameters: # seeds is actually a comma-delimited list of addresses. # Ex: ",," - - seeds: "172.17.0.3" + - seeds: "172.17.0.2" # For workloads with more data than can fit in memory, Cassandra's # bottleneck will be reads that need to fetch data from @@ -572,7 +572,7 @@ ssl_storage_port: 7001 # # Setting listen_address to 0.0.0.0 is always wrong. # -listen_address: 172.17.0.3 +listen_address: 172.17.0.2 # Set listen_address OR listen_interface, not both. Interfaces must correspond # to a single address, IP aliasing is not supported. @@ -586,7 +586,7 @@ listen_address: 172.17.0.3 # Address to broadcast to other Cassandra nodes # Leaving this blank will set it to the same value as listen_address -broadcast_address: 172.17.0.3 +broadcast_address: 172.17.0.2 # When using multiple physical network interfaces, set this # to true to listen on broadcast_address in addition to @@ -668,7 +668,7 @@ rpc_port: 9160 # be set to 0.0.0.0. If left blank, this will be set to the value of # rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must # be set. -broadcast_rpc_address: 172.17.0.3 +broadcast_rpc_address: 172.17.0.2 # enable or disable keepalive on rpc/native connections rpc_keepalive: true From 9a60ec9fda6453a18dbe2f6594b67e5cf4d5c2e1 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 14:59:55 -0700 Subject: [PATCH 117/162] Update interface name from Wrapper to a more descriptive RunnerUtil --- builtin/logical/database/dbplugin/client.go | 2 +- builtin/logical/database/dbplugin/plugin.go | 2 +- helper/pluginutil/runner.go | 8 ++++---- helper/pluginutil/tls.go | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 8cfc3aad0..0c095f891 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -29,7 +29,7 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (Database, error) { +func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner) (Database, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 21812423c..941f7aa04 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -31,7 +31,7 @@ type Statements struct { // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. -func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (Database, error) { +func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { // Look for plugin in the plugin catalog pluginRunner, err := sys.LookupPlugin(pluginName) if err != nil { diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 6a8df7385..0617f7624 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -19,15 +19,15 @@ type Looker interface { // Wrapper interface defines the functions needed by the runner to wrap the // metadata needed to run a plugin process. This includes looking up Mlock // configuration and wrapping data in a respose wrapped token. -type Wrapper interface { +type RunnerUtil interface { ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockEnabled() bool } // LookWrapper defines the functions for both Looker and Wrapper -type LookWrapper interface { +type LookRunnerUtil interface { Looker - Wrapper + RunnerUtil } // PluginRunner defines the metadata needed to run a plugin securely with @@ -43,7 +43,7 @@ type PluginRunner struct { // Run takes a wrapper instance, and the go-plugin paramaters and executes a // plugin. -func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { +func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate certBytes, key, err := GenerateCert() if err != nil { diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index ee0c54d89..05804a33b 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -97,7 +97,7 @@ func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config // WrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys Wrapper, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { +func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { rawKey, err := x509.MarshalECPrivateKey(key) if err != nil { return "", err From b3819c433b981a23146d0b6e45516b065230db72 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 15:30:56 -0700 Subject: [PATCH 118/162] Don't store an error response as a package variable --- builtin/logical/database/path_config_connection.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4c0863fd7..f37674285 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -11,8 +11,8 @@ import ( ) var ( - respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") - respErrEmptyName = logical.ErrorResponse("empty name attribute given") + respErrEmptyPluginName = "empty plugin name" + respErrEmptyName = "empty name attribute given" ) // DatabaseConfig is used by the Factory function to configure a Database @@ -51,7 +51,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } // Grab the mutex lock @@ -120,7 +120,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } entry, err := req.Storage.Get(fmt.Sprintf("config/%s", name)) @@ -146,7 +146,7 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } err := req.Storage.Delete(fmt.Sprintf("config/%s", name)) @@ -176,12 +176,12 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { pluginName := data.Get("plugin_name").(string) if pluginName == "" { - return respErrEmptyPluginName, nil + return logical.ErrorResponse(respErrEmptyPluginName), nil } name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } verifyConnection := data.Get("verify_connection").(bool) From 0e70ba8dbc5dbb80187824bba8cf796158570ba5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 15:43:21 -0700 Subject: [PATCH 119/162] Add test for custiom mssql revoke statement --- plugins/database/mssql/mssql_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 0dc18cb3e..830e38abb 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -122,6 +122,26 @@ func TestMSSQL_RevokeUser(t *testing.T) { if err := testCredsExist(t, connURL, username, password); err == nil { t.Fatal("Credentials were not revoked") } + + username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statememt + statements.RevocationStatements = testMSSQLDrop + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } } func testCredsExist(t testing.TB, connURL, username, password string) error { @@ -140,3 +160,8 @@ const testMSSQLRole = ` CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; CREATE USER [{{name}}] FOR LOGIN [{{name}}]; GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` + +const testMSSQLDrop = ` +DROP USER [{{name}}]; +DROP LOGIN [{{name}}]; +` From 98e111d4cd3fb214492b1018843755f5068a3ace Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 15:45:17 -0700 Subject: [PATCH 120/162] Prepend a 'v-' to the sql username strings --- plugins/helper/database/credsutil/sql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/helper/database/credsutil/sql.go b/plugins/helper/database/credsutil/sql.go index 23e98102f..a7929ccb1 100644 --- a/plugins/helper/database/credsutil/sql.go +++ b/plugins/helper/database/credsutil/sql.go @@ -21,7 +21,7 @@ func (scp *SQLCredentialsProducer) GenerateUsername(displayName string) (string, if err != nil { return "", err } - username := fmt.Sprintf("%s-%s", displayName, userUUID) + username := fmt.Sprintf("v-%s-%s", displayName, userUUID) if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { username = username[:scp.UsernameLen] } From a96309774761a0166bd98dd55023d19f9d982d7a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 01:59:36 -0700 Subject: [PATCH 121/162] Add internals doc for plugins --- website/source/docs/internals/plugins.html.md | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 website/source/docs/internals/plugins.html.md diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md new file mode 100644 index 000000000..5c396573d --- /dev/null +++ b/website/source/docs/internals/plugins.html.md @@ -0,0 +1,102 @@ +--- +layout: "docs" +page_title: "Plugin System" +sidebar_current: "docs-internals-plugins" +description: |- + Learn about Vault's plugin system. +--- + +# Plugin System +Certain Vault backends utilize plugins to extend their functionality outside of +what is available in the core vault code. Often times these backends will +provide both builtin plugins and a mechanism for executing external plugins. +Builtin plugins are shipped with vault, often for commonly used implementations, +and require no additional operator intervention to run. Builtin plugins are +just like any other backend code inside vault. External plugins, on the other +hand, are not shipped with the vault binary and must be registered to vault by +a privileged vault user. This section of the documentation will describe the +architecture and security of external plugins. + +# Plugin Architecture +Vault's plugins are completely separate, standalone applications that Vault +executes and communicates with over RPC. This means the plugin process does not +share the same memory space as Vault and therefore can only access the +interfaces and arguments given to it. This also means a crash in a plugin can not +crash the entirety of Vault. + +## Plugin Communication +Vault creates a mutually authenticated TLS connection for communication with the +plugin's RPC server. While invoking the plugin process Vault passes a [wrapping +token](https://www.vaultproject.io/docs/concepts/response-wrapping.html) to the +plugin process' environment. This token is single use and has a short TTL. Once +unwrapped, it provides the plugin with a unique generated TLS certificate and +private key for it to use to talk to the original vault process. + +## Plugin Registration +An important aspect of Vault's plugin system is designed to ensure the plugin +invoked by vault is authentic and maintains integrity. There are two components +that a Vault operator needs to configure before external plugins can be run. + +### Plugin Directory +The plugin directory is a configuration option of Vault, and can be specified in +the [configuration file](https://www.vaultproject.io/docs/configuration/index.html). +This setting specifies a directory that all plugin binaries must live. A plugin +can not be added to vault unless it exists in the plugin directory. There is no +default for this configuration option, and if it is not set plugins can not be +added to vault. + +~> Warning: A vault operator should take care to lock down the permissions on +this directory to ensure a plugin can not be modified by an unauthorized user +between the time of the SHA check and the time of plugin execution. + +### Plugin Catalog +The plugin catalog is Vault's list of approved plugins. The catalog is stored in +Vault's barrier and can only be updated by a vault user with sudo permissions. +Upon adding a new plugin the SHA256 sum of the executable and the command that +should be used to run the plugin must be provided. The catalog will make sure +the executable referenced in the command exists in the plugin directory. When +added to the catalog the plugin is not automatically executed, it instead +becomes visible to backends and can be executed by them. + +### Plugin Execution +When a backend executes a plugin it first checks the executable's SHA256 sum +against the one configured in the plugin catalog. Like Vault, plugins support +the use of mlock when availible. + +# Plugin Development +Because Vault communicates to plugins over a RPC interface, you can build and +distribute a plugin for Vault without having to rebuild Vault itself. This makes +it easy for you to build a Vault plugin for your organization's internal use, +for a proprietary API that you don't want to open source, or to prototype +something before contributing it back to the main project. + +In theory, because the plugin interface is HTTP, you could even develop a plugin +using a completely different programming language! (Disclaimer, you would also +have to re-implement the plugin API which is not a trivial amount of work.) + +~> Advanced topic! Plugin development is a highly advanced topic in Vault, and +is not required knowledge for day-to-day usage. If you don't plan on writing any +plugins, we recommend not reading this section of the documentation. + +Developing a plugin is simple. The only knowledge necessary to write +a plugin is basic command-line skills and basic knowledge of the +[Go programming language](http://golang.org). + +You're plugin implementation just needs to satisfy the interface for the plugin +type you want to build. You can find these definitions in the docs for the +backend running the plugin. + +```go +package main + +import ( + plugin "github.com/hashicorp/vault/builtin/logcial/database/dbplugin" +) + +func main() { + plugin.Serve(new(MyPlugin)) +} +``` + +And that's basically it! You would just need to change MyPlugin to your actual +plugin. From f17c50108fc7a2b6a72b50eb1cd88bc1422a43f6 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 02:00:04 -0700 Subject: [PATCH 122/162] Add plugins interal page to the sidebar: --- website/source/layouts/docs.erb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index 8f2686e64..32e2a7e7a 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -35,6 +35,10 @@ > Replication + + > + Plugins + From c8bbea9f37713e4d31fe789a50259224afe967c2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 02:00:39 -0700 Subject: [PATCH 123/162] Rename NewPluginServer to just Serve --- builtin/logical/database/dbplugin/server.go | 4 ++-- plugins/database/mssql/mssql.go | 2 +- plugins/database/mysql/mysql.go | 2 +- plugins/database/postgresql/postgresql.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 04cc3d7e9..32c377e13 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -7,10 +7,10 @@ import ( "github.com/hashicorp/vault/helper/pluginutil" ) -// NewPluginServer is called from within a plugin and wraps the provided +// Serve is called from within a plugin and wraps the provided // Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func NewPluginServer(db Database) { +func Serve(db Database) { dbPlugin := &DatabasePlugin{ impl: db, } diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index b608428e5..d82efce6f 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -45,7 +45,7 @@ func Run() error { return err } - dbplugin.NewPluginServer(dbType.(*MSSQL)) + dbplugin.Serve(dbType.(*MSSQL)) return nil } diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 6485aaa86..7eb680759 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -48,7 +48,7 @@ func Run() error { return err } - dbplugin.NewPluginServer(dbType.(*MySQL)) + dbplugin.Serve(dbType.(*MySQL)) return nil } diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index e90e0f8cb..0889a86f5 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -41,7 +41,7 @@ func Run() error { return err } - dbplugin.NewPluginServer(dbType.(*PostgreSQL)) + dbplugin.Serve(dbType.(*PostgreSQL)) return nil } From ca7ff89bcb4a4d3f1e62d7f70e4611dcce9f9810 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 02:22:06 -0700 Subject: [PATCH 124/162] Fix documentation --- plugins/database/postgresql/postgresql.go | 2 +- website/source/docs/internals/plugins.html.md | 20 ++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 0889a86f5..bc5b14544 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -34,7 +34,7 @@ func New() (interface{}, error) { return dbType, nil } -// Run instatiates a PostgreSQL object, and runs the RPC server for the plugin +// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin func Run() error { dbType, err := New() if err != nil { diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md index 5c396573d..a3baafff0 100644 --- a/website/source/docs/internals/plugins.html.md +++ b/website/source/docs/internals/plugins.html.md @@ -33,7 +33,7 @@ unwrapped, it provides the plugin with a unique generated TLS certificate and private key for it to use to talk to the original vault process. ## Plugin Registration -An important aspect of Vault's plugin system is designed to ensure the plugin +An important consideration of Vault's plugin system is to ensure the plugin invoked by vault is authentic and maintains integrity. There are two components that a Vault operator needs to configure before external plugins can be run. @@ -52,16 +52,18 @@ between the time of the SHA check and the time of plugin execution. ### Plugin Catalog The plugin catalog is Vault's list of approved plugins. The catalog is stored in Vault's barrier and can only be updated by a vault user with sudo permissions. -Upon adding a new plugin the SHA256 sum of the executable and the command that -should be used to run the plugin must be provided. The catalog will make sure -the executable referenced in the command exists in the plugin directory. When -added to the catalog the plugin is not automatically executed, it instead -becomes visible to backends and can be executed by them. +Upon adding a new plugin the plugin name, SHA256 sum of the executable, and the +command that should be used to run the plugin must be provided. The catalog will +make sure the executable referenced in the command exists in the plugin +directory. When added to the catalog the plugin is not automatically executed, +it instead becomes visible to backends and can be executed by them. ### Plugin Execution -When a backend executes a plugin it first checks the executable's SHA256 sum -against the one configured in the plugin catalog. Like Vault, plugins support -the use of mlock when availible. +When a backend wants to run a plugin, it first looks up the plugin, by name, in +the catalog. It then checks the executable's SHA256 sum against the one +configured in the plugin catalog. Finally vault runs the command configured in +the catalog, sending along the JWT formatted response wrapping token and mlock +settings (like Vault, plugins support the use of mlock when availible). # Plugin Development Because Vault communicates to plugins over a RPC interface, you can build and From 712cacaf4d43af4315991c798d5f128ee12c3115 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 2 May 2017 16:26:32 -0400 Subject: [PATCH 125/162] Add website skeleton --- .../docs/secrets/databases/cassandra.html.md | 9 +++++++ .../docs/secrets/databases/index.html.md | 11 ++++++++ .../docs/secrets/databases/mssql.html.md | 9 +++++++ .../secrets/databases/mysql-maria.html.md | 9 +++++++ .../docs/secrets/databases/postgresql.html.md | 9 +++++++ website/source/layouts/docs.erb | 26 ++++++++++++++++--- 6 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 website/source/docs/secrets/databases/cassandra.html.md create mode 100644 website/source/docs/secrets/databases/index.html.md create mode 100644 website/source/docs/secrets/databases/mssql.html.md create mode 100644 website/source/docs/secrets/databases/mysql-maria.html.md create mode 100644 website/source/docs/secrets/databases/postgresql.html.md diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md new file mode 100644 index 000000000..012e7db5b --- /dev/null +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "Cassandra Database Plugin" +sidebar_current: "docs-secrets-databases-cassandra" +description: |- + The Cassandra plugin for Vault's Database backend generates database credentials to access Cassandra. +--- + +# Cassandra Database Plugin diff --git a/website/source/docs/secrets/databases/index.html.md b/website/source/docs/secrets/databases/index.html.md new file mode 100644 index 000000000..20f7bbed2 --- /dev/null +++ b/website/source/docs/secrets/databases/index.html.md @@ -0,0 +1,11 @@ +--- +layout: "docs" +page_title: "Databases" +sidebar_current: "docs-secrets-databases" +description: |- + Top page for database secret backend information +--- + +# Databases + +Something diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md new file mode 100644 index 000000000..32ecf7775 --- /dev/null +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "MSSQL Database Plugin" +sidebar_current: "docs-secrets-databases-mssql" +description: |- + The MSSQL plugin for Vault's Database backend generates database credentials to access Microsoft SQL Server. +--- + +# MSSQL Database Plugin diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md new file mode 100644 index 000000000..1ee601dbd --- /dev/null +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "MySQL/MariaDB Database Plugin" +sidebar_current: "docs-secrets-databases-mysql-maria" +description: |- + The MySQL/MariaDB plugin for Vault's Database backend generates database credentials to access MySQL and MariaDB servers. +--- + +# MySQL/MariaDB Database Plugin diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md new file mode 100644 index 000000000..5de340043 --- /dev/null +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "PostgreSQL Database Plugin" +sidebar_current: "docs-secrets-databases-postgresql" +description: |- + The PostgreSQL plugin for Vault's Database backend generates database credentials to access PostgreSQL. +--- + +# PostgreSQL Database Plugin diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index 32e2a7e7a..a6afd9c2d 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -208,7 +208,7 @@ > - Cassandra + Cassandra (Deprecated) > @@ -219,6 +219,24 @@ Cubbyhole + > + Databases (Beta) + + + > Generic @@ -228,11 +246,11 @@ > - MSSQL + MSSQL (Deprecated) > - MySQL + MySQL (Deprecated) > @@ -240,7 +258,7 @@ > - PostgreSQL + PostgreSQL (Deprecated) > From 29d9b831d3a59dc9154e6b4c7df3983a0477bd14 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 14:40:11 -0700 Subject: [PATCH 126/162] Update the api for serving plugins and provide a utility to pass TLS data for commuinicating with the vault process --- builtin/logical/database/backend_test.go | 68 ++++-- .../logical/database/dbplugin/plugin_test.go | 55 +++-- builtin/logical/database/dbplugin/server.go | 13 +- helper/pluginutil/runner.go | 42 ++++ helper/pluginutil/tls.go | 199 +++++++++--------- .../cassandra-database-plugin/main.go | 7 +- plugins/database/cassandra/cassandra.go | 6 +- .../mssql/mssql-database-plugin/main.go | 7 +- plugins/database/mssql/mssql.go | 6 +- .../mysql/mysql-database-plugin/main.go | 7 +- plugins/database/mysql/mysql.go | 6 +- .../postgresql-database-plugin/main.go | 7 +- plugins/database/postgresql/postgresql.go | 6 +- plugins/serve.go | 31 +++ vault/testing.go | 3 + 15 files changed, 310 insertions(+), 153 deletions(-) create mode 100644 plugins/serve.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 08317cbdc..70ec22ee2 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -4,12 +4,13 @@ import ( "database/sql" "fmt" "log" - "net" + stdhttp "net/http" "os" "reflect" "sync" "testing" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" @@ -77,13 +78,30 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac return } -func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) { - core, _, token, ln := vault.TestCoreUnsealedWithListener(t) - http.TestServerWithListener(t, ln, "", core) - sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", "TestBackend_PluginMain") +func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "database": Factory, + }, + } - return core, ln, sys, token + handler1 := stdhttp.NewServeMux() + handler2 := stdhttp.NewServeMux() + handler3 := stdhttp.NewServeMux() + + // Chicken-and-egg: Handler needs a core. So we create handlers first, then + // add routes chained to a Handler-created handler. + cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) + handler1.Handle("/", http.Handler(cores[0].Core)) + handler2.Handle("/", http.Handler(cores[1].Core)) + handler3.Handle("/", http.Handler(cores[2].Core)) + + core := cores[0] + + sys := vault.TestDynamicSystemView(core.Core) + vault.TestAddTestPlugin(t, core.Core, "postgresql-database-plugin", "TestBackend_PluginMain") + + return cores, sys } func TestBackend_PluginMain(t *testing.T) { @@ -91,14 +109,20 @@ func TestBackend_PluginMain(t *testing.T) { return } - postgresql.Run() + err := postgresql.Run(&api.TLSConfig{Insecure: true}) + if err != nil { + t.Fatal(err) + } + t.Fatal("We shouldn't get here") } func TestBackend_config_connection(t *testing.T) { var resp *logical.Response var err error - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -147,8 +171,10 @@ func TestBackend_config_connection(t *testing.T) { } func TestBackend_basic(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -238,8 +264,10 @@ func TestBackend_basic(t *testing.T) { } func TestBackend_connectionCrud(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -383,8 +411,10 @@ func TestBackend_connectionCrud(t *testing.T) { } func TestBackend_roleCrud(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -493,8 +523,10 @@ func TestBackend_roleCrud(t *testing.T) { } } func TestBackend_allowedRoles(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 1587ba24a..c38d85ed3 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -2,15 +2,17 @@ package dbplugin_test import ( "errors" - "net" + stdhttp "net/http" "os" "testing" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/vault" log "github.com/mgutz/logxi/v1" ) @@ -72,13 +74,26 @@ func (m *mockPlugin) Close() error { return nil } -func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { - core, _, _, ln := vault.TestCoreUnsealedWithListener(t) - http.TestServerWithListener(t, ln, "", core) - sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "test-plugin", "TestPlugin_Main") +func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { + coreConfig := &vault.CoreConfig{} - return core, ln, sys + handler1 := stdhttp.NewServeMux() + handler2 := stdhttp.NewServeMux() + handler3 := stdhttp.NewServeMux() + + // Chicken-and-egg: Handler needs a core. So we create handlers first, then + // add routes chained to a Handler-created handler. + cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) + handler1.Handle("/", http.Handler(cores[0].Core)) + handler2.Handle("/", http.Handler(cores[1].Core)) + handler3.Handle("/", http.Handler(cores[2].Core)) + + core := cores[0] + + sys := vault.TestDynamicSystemView(core.Core) + vault.TestAddTestPlugin(t, core.Core, "test-plugin", "TestPlugin_Main") + + return cores, sys } // This is not an actual test case, it's a helper function that will be executed @@ -92,12 +107,14 @@ func TestPlugin_Main(t *testing.T) { users: make(map[string][]string), } - dbplugin.NewPluginServer(plugin) + plugins.Serve(plugin, &api.TLSConfig{Insecure: true}) } func TestPlugin_Initialize(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -120,8 +137,10 @@ func TestPlugin_Initialize(t *testing.T) { } func TestPlugin_CreateUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -155,8 +174,10 @@ func TestPlugin_CreateUser(t *testing.T) { } func TestPlugin_RenewUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -184,8 +205,10 @@ func TestPlugin_RenewUser(t *testing.T) { } func TestPlugin_RevokeUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 32c377e13..9546d092c 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,16 +1,15 @@ package dbplugin import ( - "fmt" + "crypto/tls" "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vault/helper/pluginutil" ) // Serve is called from within a plugin and wraps the provided // Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func Serve(db Database) { +func Serve(db Database, tlsProvider func() (*tls.Config, error)) { dbPlugin := &DatabasePlugin{ impl: db, } @@ -20,16 +19,10 @@ func Serve(db Database) { "database": dbPlugin, } - err := pluginutil.OptionallyEnableMlock() - if err != nil { - fmt.Println(err) - return - } - plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, - TLSProvider: pluginutil.VaultPluginTLSProvider, + TLSProvider: tlsProvider, }) } diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 0617f7624..91439a3b8 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -2,11 +2,13 @@ package pluginutil import ( "crypto/sha256" + "flag" "fmt" "os/exec" "time" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/wrapping" ) @@ -87,3 +89,43 @@ func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugi return client, nil } + +type APIClientMeta struct { + // These are set by the command line flags. + flagCACert string + flagCAPath string + flagClientCert string + flagClientKey string + flagInsecure bool +} + +func (f *APIClientMeta) FlagSet() *flag.FlagSet { + fs := flag.NewFlagSet("tls settings", flag.ContinueOnError) + + fs.StringVar(&f.flagCACert, "ca-cert", "", "") + fs.StringVar(&f.flagCAPath, "ca-path", "", "") + fs.StringVar(&f.flagClientCert, "client-cert", "", "") + fs.StringVar(&f.flagClientKey, "client-key", "", "") + fs.BoolVar(&f.flagInsecure, "insecure", false, "") + fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "") + + return fs +} + +func (f *APIClientMeta) GetTLSConfig() *api.TLSConfig { + // If we need custom TLS configuration, then set it + if f.flagCACert != "" || f.flagCAPath != "" || f.flagClientCert != "" || f.flagClientKey != "" || f.flagInsecure { + t := &api.TLSConfig{ + CACert: f.flagCACert, + CAPath: f.flagCAPath, + ClientCert: f.flagClientCert, + ClientKey: f.flagClientKey, + TLSServerName: "", + Insecure: f.flagInsecure, + } + + return t + } + + return nil +} diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 05804a33b..b355079d6 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -116,109 +116,114 @@ func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) ( // VaultPluginTLSProvider is run inside a plugin and retrives the response // wrapped TLS certificate from vault. It returns a configured TLS Config. -func VaultPluginTLSProvider() (*tls.Config, error) { - unwrapToken := os.Getenv(PluginUnwrapTokenEnv) +func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) { + return func() (*tls.Config, error) { + unwrapToken := os.Getenv(PluginUnwrapTokenEnv) - // Ensure unwrap token is a JWT - if strings.Count(unwrapToken, ".") != 2 { - return nil, errors.New("Could not parse unwraptoken") - } + // Ensure unwrap token is a JWT + if strings.Count(unwrapToken, ".") != 2 { + return nil, errors.New("Could not parse unwraptoken") + } - // Parse the JWT and retrieve the vault address - wt, err := jws.ParseJWT([]byte(unwrapToken)) - if err != nil { - return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) - } - if wt == nil { - return nil, errors.New("nil decoded token") - } + // Parse the JWT and retrieve the vault address + wt, err := jws.ParseJWT([]byte(unwrapToken)) + if err != nil { + return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) + } + if wt == nil { + return nil, errors.New("nil decoded token") + } - addrRaw := wt.Claims().Get("addr") - if addrRaw == nil { - return nil, errors.New("decoded token does not contain primary cluster address") - } - vaultAddr, ok := addrRaw.(string) - if !ok { - return nil, errors.New("decoded token's address not valid") - } - if vaultAddr == "" { - return nil, errors.New(`no address for the vault found`) - } + addrRaw := wt.Claims().Get("addr") + if addrRaw == nil { + return nil, errors.New("decoded token does not contain primary cluster address") + } + vaultAddr, ok := addrRaw.(string) + if !ok { + return nil, errors.New("decoded token's address not valid") + } + if vaultAddr == "" { + return nil, errors.New(`no address for the vault found`) + } - // Sanity check the value - if _, err := url.Parse(vaultAddr); err != nil { - return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) - } + // Sanity check the value + if _, err := url.Parse(vaultAddr); err != nil { + return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) + } - // Unwrap the token - clientConf := api.DefaultConfig() - clientConf.Address = vaultAddr - client, err := api.NewClient(clientConf) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } + // Unwrap the token + clientConf := api.DefaultConfig() + clientConf.Address = vaultAddr + if apiTLSConfig != nil { + clientConf.ConfigureTLS(apiTLSConfig) + } + client, err := api.NewClient(clientConf) + if err != nil { + return nil, errwrap.Wrapf("error during api client creation: {{err}}", err) + } - secret, err := client.Logical().Unwrap(unwrapToken) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - if secret == nil { - return nil, errors.New("error during token unwrap request secret is nil") - } + secret, err := client.Logical().Unwrap(unwrapToken) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + if secret == nil { + return nil, errors.New("error during token unwrap request secret is nil") + } - // Retrieve and parse the server's certificate - serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") + // Retrieve and parse the server's certificate + serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCert, err := x509.ParseCertificate(serverCertBytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Retrieve and parse the server's private key + serverKeyB64, ok := secret.Data["ServerKey"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Add CA cert to the cert pool + caCertPool := x509.NewCertPool() + caCertPool.AddCert(serverCert) + + // Build a certificate object out of the server's cert and private key. + cert := tls.Certificate{ + Certificate: [][]byte{serverCertBytes}, + PrivateKey: serverKey, + Leaf: serverCert, + } + + // Setup TLS config + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + RootCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + // TLS 1.2 minimum + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil } - - serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverCert, err := x509.ParseCertificate(serverCertBytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - // Retrieve and parse the server's private key - serverKeyB64, ok := secret.Data["ServerKey"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - // Add CA cert to the cert pool - caCertPool := x509.NewCertPool() - caCertPool.AddCert(serverCert) - - // Build a certificate object out of the server's cert and private key. - cert := tls.Certificate{ - Certificate: [][]byte{serverCertBytes}, - PrivateKey: serverKey, - Leaf: serverCert, - } - - // Setup TLS config - tlsConfig := &tls.Config{ - ClientCAs: caCertPool, - RootCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - // TLS 1.2 minimum - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{cert}, - } - tlsConfig.BuildNameToCertificate() - - return tlsConfig, nil } diff --git a/plugins/database/cassandra/cassandra-database-plugin/main.go b/plugins/database/cassandra/cassandra-database-plugin/main.go index 79f0e0dbe..bb3f44142 100644 --- a/plugins/database/cassandra/cassandra-database-plugin/main.go +++ b/plugins/database/cassandra/cassandra-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/cassandra" ) func main() { - err := cassandra.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := cassandra.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index bf1cbab92..60e445ff6 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -6,8 +6,10 @@ import ( "github.com/gocql/gocql" multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -41,13 +43,13 @@ func New() (interface{}, error) { } // Run instantiates a Cassandra object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.NewPluginServer(dbType.(*Cassandra)) + plugins.Serve(dbType.(*Cassandra), apiTLSConfig) return nil } diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go index ead1cf842..d52fd13db 100644 --- a/plugins/database/mssql/mssql-database-plugin/main.go +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/mssql" ) func main() { - err := mssql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := mssql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index d82efce6f..9b22aa87c 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -6,8 +6,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -39,13 +41,13 @@ func New() (interface{}, error) { } // Run instantiates a MSSQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*MSSQL)) + plugins.Serve(dbType.(*MSSQL), apiTLSConfig) return nil } diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go index c0ec75c9c..a9389f504 100644 --- a/plugins/database/mysql/mysql-database-plugin/main.go +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/mysql" ) func main() { - err := mysql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := mysql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 7eb680759..7a44d7341 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -5,8 +5,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -42,13 +44,13 @@ func New() (interface{}, error) { } // Run instantiates a MySQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*MySQL)) + plugins.Serve(dbType.(*MySQL), apiTLSConfig) return nil } diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go index 9b9b813c4..e6acb0584 100644 --- a/plugins/database/postgresql/postgresql-database-plugin/main.go +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/postgresql" ) func main() { - err := postgresql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := postgresql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index bc5b14544..d60ef8bbe 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -6,8 +6,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -35,13 +37,13 @@ func New() (interface{}, error) { } // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*PostgreSQL)) + plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig) return nil } diff --git a/plugins/serve.go b/plugins/serve.go new file mode 100644 index 000000000..263b301f7 --- /dev/null +++ b/plugins/serve.go @@ -0,0 +1,31 @@ +package plugins + +import ( + "fmt" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/pluginutil" +) + +// Serve is used to start a plugin's RPC server. It takes an interface that must +// implement a known plugin interface to vault and an optional api.TLSConfig for +// use during the inital unwrap request to vault. The api config is particulary +// useful when vault is setup to require client cert checking. +func Serve(plugin interface{}, tlsConfig *api.TLSConfig) { + tlsProvider := pluginutil.VaultPluginTLSProvider(tlsConfig) + + err := pluginutil.OptionallyEnableMlock() + if err != nil { + fmt.Println(err) + return + } + + switch p := plugin.(type) { + case dbplugin.Database: + dbplugin.Serve(p, tlsProvider) + default: + fmt.Println("Unsuported plugin type") + } + +} diff --git a/vault/testing.go b/vault/testing.go index b2fe36b33..36bbb1276 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -790,6 +790,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c1.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -799,6 +800,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c2.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -808,6 +810,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c2.redirectAddr = coreConfig.RedirectAddr // // Clustering setup From 5e0c03415b22e72456f5dc74e6ea55e91bd9bb06 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 14:44:14 -0700 Subject: [PATCH 127/162] Don't need to explictly set redirectAddrs --- vault/testing.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/vault/testing.go b/vault/testing.go index 36bbb1276..b2fe36b33 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -790,7 +790,6 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } - c1.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -800,7 +799,6 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } - c2.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -810,7 +808,6 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } - c2.redirectAddr = coreConfig.RedirectAddr // // Clustering setup From f644c34c5b2957046422d8e0f54f5925424eb36c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 14:52:48 -0700 Subject: [PATCH 128/162] Remove unused TestCoreUnsealedWithListener function --- vault/testing.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vault/testing.go b/vault/testing.go index b2fe36b33..a8c1f16bd 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -234,18 +234,6 @@ func TestCoreUnsealedBackend(t testing.TB, backend physical.Backend) (*Core, [][ return core, keys, token } -func TestCoreUnsealedWithListener(t testing.TB) (*Core, [][]byte, string, net.Listener) { - core, keys, token := TestCoreUnsealed(t) - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %s", err) - } - addr := "http://" + ln.Addr().String() - core.redirectAddr = addr - - return core, keys, token, ln -} - func testTokenStore(t testing.TB, c *Core) *TokenStore { me := &MountEntry{ Table: credentialTableType, From fdf045b3bda3acca80a0c228b3d53e52d5cf9d81 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 15:59:08 -0700 Subject: [PATCH 129/162] Fix a few PR comments --- helper/pluginutil/runner.go | 1 - plugins/serve.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 91439a3b8..9dbe5c51b 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -106,7 +106,6 @@ func (f *APIClientMeta) FlagSet() *flag.FlagSet { fs.StringVar(&f.flagCAPath, "ca-path", "", "") fs.StringVar(&f.flagClientCert, "client-cert", "", "") fs.StringVar(&f.flagClientKey, "client-key", "", "") - fs.BoolVar(&f.flagInsecure, "insecure", false, "") fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "") return fs diff --git a/plugins/serve.go b/plugins/serve.go index 263b301f7..a40fc5b14 100644 --- a/plugins/serve.go +++ b/plugins/serve.go @@ -25,7 +25,7 @@ func Serve(plugin interface{}, tlsConfig *api.TLSConfig) { case dbplugin.Database: dbplugin.Serve(p, tlsProvider) default: - fmt.Println("Unsuported plugin type") + fmt.Println("Unsupported plugin type") } } From 20994c1247164822dfe3e66fce0fe266027831d9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 16:20:07 -0700 Subject: [PATCH 130/162] Fix wording in docs --- website/source/docs/internals/plugins.html.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md index a3baafff0..f1a720efb 100644 --- a/website/source/docs/internals/plugins.html.md +++ b/website/source/docs/internals/plugins.html.md @@ -35,7 +35,8 @@ private key for it to use to talk to the original vault process. ## Plugin Registration An important consideration of Vault's plugin system is to ensure the plugin invoked by vault is authentic and maintains integrity. There are two components -that a Vault operator needs to configure before external plugins can be run. +that a Vault operator needs to configure before external plugins can be run, the +plugin directory and the plugin catalog entry. ### Plugin Directory The plugin directory is a configuration option of Vault, and can be specified in @@ -52,7 +53,7 @@ between the time of the SHA check and the time of plugin execution. ### Plugin Catalog The plugin catalog is Vault's list of approved plugins. The catalog is stored in Vault's barrier and can only be updated by a vault user with sudo permissions. -Upon adding a new plugin the plugin name, SHA256 sum of the executable, and the +Upon adding a new plugin, the plugin name, SHA256 sum of the executable, and the command that should be used to run the plugin must be provided. The catalog will make sure the executable referenced in the command exists in the plugin directory. When added to the catalog the plugin is not automatically executed, From b60ff2048da1ce6737a91415e6db78f0c2e008ac Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 17:04:49 -0700 Subject: [PATCH 131/162] Update docs and add cassandra as a builtin plugin --- helper/builtinplugins/builtin.go | 2 ++ website/source/docs/internals/plugins.html.md | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index c20a92603..3dec8588b 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,6 +1,7 @@ package builtinplugins import ( + "github.com/hashicorp/vault/plugins/database/cassandra" "github.com/hashicorp/vault/plugins/database/mssql" "github.com/hashicorp/vault/plugins/database/mysql" "github.com/hashicorp/vault/plugins/database/postgresql" @@ -12,6 +13,7 @@ var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ "mysql-database-plugin": mysql.New, "postgresql-database-plugin": postgresql.New, "mssql-database-plugin": mssql.New, + "cassandra-database-plugin": cassandra.New, } func Get(name string) (BuiltinFactory, bool) { diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md index f1a720efb..600bc034e 100644 --- a/website/source/docs/internals/plugins.html.md +++ b/website/source/docs/internals/plugins.html.md @@ -93,11 +93,11 @@ backend running the plugin. package main import ( - plugin "github.com/hashicorp/vault/builtin/logcial/database/dbplugin" + "github.com/hashicorp/vault/plugins" ) func main() { - plugin.Serve(new(MyPlugin)) + plugins.Serve(new(MyPlugin), nil) } ``` From 7ae8f02f4be6ee3be64323933ab1c5cb57beaabf Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 17:19:49 -0700 Subject: [PATCH 132/162] Only wrap in tracing middleware if the logger is set to trace level --- .../database/dbplugin/databasemiddleware.go | 50 ++++++++----------- builtin/logical/database/dbplugin/plugin.go | 10 ++-- 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 13591e516..83f57ef87 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -23,57 +23,47 @@ func (mw *databaseTracingMiddleware) Type() (string, error) { } func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) return mw.next.CreateUser(statements, usernamePrefix, expiration) } func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) - } + mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) return mw.next.RenewUser(statements, username, expiration) } func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) return mw.next.RevokeUser(statements, username) } func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) return mw.next.Initialize(conf, verifyConnection) } func (mw *databaseTracingMiddleware) Close() (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) return mw.next.Close() } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 941f7aa04..bc63594ae 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -73,10 +73,12 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. } // Wrap with tracing middleware - db = &databaseTracingMiddleware{ - next: db, - typeStr: typeStr, - logger: logger, + if logger.IsTrace() { + db = &databaseTracingMiddleware{ + next: db, + typeStr: typeStr, + logger: logger, + } } return db, nil From 50ac77be5171ac7153fa7c08d0ab9aace5a56cdd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 22:24:31 -0700 Subject: [PATCH 133/162] Update docs for the database backend and it's plugins --- .../docs/secrets/databases/cassandra.html.md | 53 +++++++++++ .../docs/secrets/databases/index.html.md | 88 ++++++++++++++++++- .../docs/secrets/databases/mssql.html.md | 51 +++++++++++ .../secrets/databases/mysql-maria.html.md | 49 +++++++++++ .../docs/secrets/databases/postgresql.html.md | 51 +++++++++++ 5 files changed, 291 insertions(+), 1 deletion(-) diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index 012e7db5b..99d3d3bf9 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -7,3 +7,56 @@ description: |- --- # Cassandra Database Plugin + +Name: `cassandra-database-plugin` + +The Cassandra Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the Cassandra database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a cassandra connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +cassandra configuration: + +``` +$ vault write database/config/cassandra \ + plugin_name=cassandra-database-plugin \ + allowed_roles="readonly" \ + hosts=localhost \ + username=cassandra \ + password=cassandra + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the cassandra connection is configured we can add a role: + +``` +$ vault write database/roles/readonly \ + db_name=cassandra \ + creation_statements="CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER; \ + GRANT SELECT ON ALL KEYSPACES TO {{username}};" \ + default_ttl="1h" \ + max_ttl="24h" + + +Success! Data written to: database/roles/readonly +``` + +This role can be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [Cassandra database +plugin API](/api/secret/database/cassandra.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + diff --git a/website/source/docs/secrets/databases/index.html.md b/website/source/docs/secrets/databases/index.html.md index 20f7bbed2..cf366d9c9 100644 --- a/website/source/docs/secrets/databases/index.html.md +++ b/website/source/docs/secrets/databases/index.html.md @@ -8,4 +8,90 @@ description: |- # Databases -Something +Name: `Database` + +The Database secret backend for Vault generates database credentials dynamically +based on configured roles. It works with a number of different databases through +a plugin interface. There are a number of builtin database types and an exposed +framework for running custom database types for extendability. This means that +services that need to access a database no longer need to hardcode credentials: +they can request them from Vault, and use Vault's leasing mechanism to more +easily roll keys. + +Additionally, it introduces a new ability: with every service accessing the +database with unique credentials, it makes auditing much easier when +questionable data access is discovered: you can track it down to the specific +instance of a service based on the SQL username. + +Vault makes use of its own internal revocation system to ensure that users +become invalid within a reasonable time of the lease expiring. + +This page will show a quick start for this backend. For detailed documentation +on every path, use vault path-help after mounting the backend. + +## Quick Start + +The first step in using the Database backend is mounting it. + +```text +$ vault mount database +Successfully mounted 'database' at 'database'! +``` + +Next, we must configure this backend to connect to a database. In this example +we will connect to a MySQL database, but the configuration details needed for +other plugin types can be found in their docs pages. This backend can configure +multiple database connections, therefore a name for the connection must be +provide; we'll call this one simply "mysql". + +``` +$ vault write database/config/mysql \ + plugin_name=mysql-database-plugin \ + connection_url="root:mysql@tcp(127.0.0.1:3306)/" \ + allowed_roles="readonly" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +The next step is to configure a role. A role is a logical name that maps to a +policy used to generate those credentials. A role needs to be configured with +the database name we created above, and the default/max TTLs. For example, lets +create a "readonly" role: + +``` +$ vault write database/roles/readonly \ + db_name=mysql \ + creation_statements="CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';GRANT SELECT ON *.* TO '{{name}}'@'%';" \ + default_ttl="1h" \ + max_ttl="24h" +Success! Data written to: database/roles/readonly +``` +By writing to the roles/readonly path we are defining the readonly role. This +role will be created by evaluating the given creation statements. By default, +the {{name}} and {{password}} fields will be populated by the plugin with +dynamically generated values. In other plugins the {{expiration}} field could +also be supported. This SQL statement is creating the named user, and then +granting it SELECT or read-only privileges to tables in the database. More +complex GRANT queries can be used to customize the privileges of the role. +Custom revocation statements could be passed too, but this plugin has a default +statement we can use. + +To generate a new set of credentials, we simply read from that role: + +``` +$ vault read database/creds/readonly +Key Value +--- ----- +lease_id database/creds/readonly/2f6a614c-4aa2-7b19-24b9-ad944a8d4de6 +lease_duration 1h0m0s +lease_renewable true +password 8cab931c-d62e-a73d-60d3-5ee85139cd66 +username v-root-e2978cd0- +``` + +## API + +The Database secret backend has a full HTTP API. Please see the [Database secret +backend API](/api/secret/database/index.html) for more details. + diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 32ecf7775..2d220ddb8 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -7,3 +7,54 @@ description: |- --- # MSSQL Database Plugin + +Name: `mssql-database-plugin` + +The MSSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MSSQL database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a MSSQL connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +configuration: + +``` +$ vault write database/config/mssql \ + plugin_name=mssql-database-plugin \ + connection_url='sqlserver://sa:yourStrong(!)Password@localhost:1433' \ + allowed_roles="readonly" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the MSSQL connection is configured we can add a role: + +``` +$ vault write database/roles/readonly \ + db_name=mssql \ + creation_statements="CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';\ + USE AdventureWorks; CREATE USER [{{name}}] FOR LOGIN [{{name}}]; \ + GRANT SELECT ON SCHEMA::dbo TO [{{name}}];" \ + default_ttl="1h" \ + max_ttl="24h" + +Success! Data written to: database/roles/readonly +``` + +This role can now be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [MSSQL database +plugin API](/api/secret/database/mssql.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index 1ee601dbd..bd61cc43b 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -7,3 +7,52 @@ description: |- --- # MySQL/MariaDB Database Plugin + +Name: `mysql-database-plugin` + +The MySQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MySQL database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a MySQL connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +configuration: + +``` +$ vault write database/config/mysql \ + plugin_name=mysql-database-plugin \ + connection_url="root:mysql@tcp(127.0.0.1:3306)/" \ + allowed_roles="readonly" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the MySQL connection is configured we can add a role: + +``` +$ vault write database/roles/readonly \ + db_name=mysql \ + creation_statements="CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';GRANT SELECT ON *.* TO '{{name}}'@'%';" \ + default_ttl="1h" \ + max_ttl="24h" + +Success! Data written to: database/roles/readonly +``` + +This role can now be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [MySQL database +plugin API](/api/secret/database/mysql.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index 5de340043..e5fee10ef 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -7,3 +7,54 @@ description: |- --- # PostgreSQL Database Plugin + +Name: `postgresql-database-plugin` + +The PostgreSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the PostgreSQL database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a PostgreSQL connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +configuration: + +``` +$ vault write database/config/postgresql \ + plugin_name=postgresql-database-plugin \ + allowed_roles="readonly" \ + connection_url="postgresql://root:root@localhost:5432/postgres" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the PostgreSQL connection is configured we can add a role. The PostgreSQL +plugin replaces `{{expiration}}` in statements with a formated timestamp: + +``` +$ vault write database/roles/readonly \ + db_name=postgresql \ + creation_statements="CREATE ROLE \"{{name}}\" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \ + GRANT SELECT ON ALL TABLES IN SCHEMA public TO \"{{name}}\";" \ + default_ttl="1h" \ + max_ttl="24h" + +Success! Data written to: database/roles/readonly +``` + +This role can be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [PostgreSQL database +plugin API](/api/secret/database/postgresql.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + From 63de72c10f5a217e6868e3a72e3498d5dc30a1f1 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 00:01:28 -0700 Subject: [PATCH 134/162] Add custom plugins docs page --- .../docs/secrets/databases/cassandra.html.md | 2 +- .../docs/secrets/databases/custom.html.md | 120 ++++++++++++++++++ .../docs/secrets/databases/mssql.html.md | 2 +- .../secrets/databases/mysql-maria.html.md | 2 +- .../docs/secrets/databases/postgresql.html.md | 2 +- website/source/layouts/docs.erb | 3 + 6 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 website/source/docs/secrets/databases/custom.html.md diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index 99d3d3bf9..b3d87f7ed 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -58,5 +58,5 @@ The full list of configurable options can be seen in the [Cassandra database plugin API](/api/secret/database/cassandra.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/docs/secrets/databases/custom.html.md b/website/source/docs/secrets/databases/custom.html.md new file mode 100644 index 000000000..5911e1d80 --- /dev/null +++ b/website/source/docs/secrets/databases/custom.html.md @@ -0,0 +1,120 @@ +--- +layout: "docs" +page_title: "Custom Database Plugins" +sidebar_current: "docs-secrets-databases-custom" +description: |- + Creating custom database plugins for Vault's Database backend to generate credentials for a database. +--- + +# Custom Database Plugins + +The Database backend allows new functionality to be added through a plugin +interface without needing to modify vault's core code. This allows you write +your own code to generate credentials in any database you wish. It also allows +databases that require dynamically linked libraries to be used with vault. + +~> **Advanced topic!** Plugin development is a highly advanced +topic in Vault, and is not required knowledge for day-to-day usage. +If you don't plan on writing any plugins, we recommend not reading +this section of the documentation. + +Please read the [Plugins internals](/docs/internals/plugins.html) docs for more +information about the plugin system before getting started building your +Database plugin. + +## Plugin Interface + +All plugins for the Database backend must implement the same simple interface. + +```go +type Database interface { + Type() (string, error) + CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) + RenewUser(statements Statements, username string, expiration time.Time) error + RevokeUser(statements Statements, username string) error + + Initialize(config map[string]interface{}, verifyConnection bool) error + Close() error +} +``` + +You'll notice the first parameter to a number of those functions is a +`Statements` struct. This struct is used to pass the Role's configured +statements to the plugin on function call. The struct is defined as: + +```go +type Statements struct { + CreationStatements string + RevocationStatements string + RollbackStatements string + RenewStatements string +} +``` + +It is up to your plugin to replace the `{{name}}`, `{{password}}`, and +`{{expiration}}` in these statements with the proper vaules. + +The `Initialize` function is passed a map of keys to values, this data is what the +user specified as the configuration for the plugin. Your plugin should use this +data to make connections to the database. It is also passed a boolean value +specifying whether or not your plugin should return an error if it is unable to +connect to the database. + +## Serving your plugin + +Once your plugin is built you should pass it to vault's `plugins` package by +calling the `Serve` method: + +```go +package main + +import ( + "github.com/hashicorp/vault/plugins" +) + +func main() { + plugins.Serve(new(MyPlugin), nil) +} +``` + +Replacing `MyPlugin` with the actual implementation of your plugin. + +The second parameter to `Serve` takes in an optional vault `api.TLSConfig` for +configuring the plugin to communicate with vault for the initial unwrap call. +This if useful if your vault setup requires client certificate checks. This +config wont be used once the plugin unwraps its own TLS cert and key. + +## Running your plugin + +The above main package, once built, will supply you with a binary of your +plugin. We also recommend if you are planning on distributing your plugin to +build with [gox](https://github.com/mitchellh/gox) for cross platform builds. + +To use your plugin with the Database backend you need to place the binary in the +plugin directory as specified in the [plugin internals](/docs/internals/plugins.html) docs. + +You should now be able to register your plugin into the vault catalog. To do +this your token will need sudo permissions. + +``` +$ vault write sys/plugins/catalog/myplugin-database-plugin \ + sha_256= \ + command="myplugin" +Success! Data written to: sys/plugins/catalog/myplugin-database-plugin +``` + +Now you should be able to configure your plugin like any other: + +``` +$ vault write database/config/myplugin \ + plugin_name=myplugin-database-plugin \ + allowed_roles="readonly" \ + myplugins_connection_details=.... + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + + + + diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 2d220ddb8..0eefe1764 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -56,5 +56,5 @@ The full list of configurable options can be seen in the [MSSQL database plugin API](/api/secret/database/mssql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index bd61cc43b..76ca193fc 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -54,5 +54,5 @@ The full list of configurable options can be seen in the [MySQL database plugin API](/api/secret/database/mysql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index e5fee10ef..81716132f 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -56,5 +56,5 @@ The full list of configurable options can be seen in the [PostgreSQL database plugin API](/api/secret/database/postgresql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index a6afd9c2d..95787e8e1 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -234,6 +234,9 @@ > PostgreSQL + > + Custom + From dbb5b38e0d804dd48b8455ec20492f6069bb7a12 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 02:13:07 -0700 Subject: [PATCH 135/162] Add API docs --- .../api/secret/databases/cassandra.html.md | 96 +++++ .../source/api/secret/databases/index.html.md | 342 ++++++++++++++++++ .../source/api/secret/databases/mssql.html.md | 60 +++ .../api/secret/databases/mysql-maria.html.md | 60 +++ .../api/secret/databases/postgresql.html.md | 60 +++ .../docs/secrets/databases/custom.html.md | 5 +- website/source/layouts/api.erb | 27 +- 7 files changed, 644 insertions(+), 6 deletions(-) create mode 100644 website/source/api/secret/databases/cassandra.html.md create mode 100644 website/source/api/secret/databases/index.html.md create mode 100644 website/source/api/secret/databases/mssql.html.md create mode 100644 website/source/api/secret/databases/mysql-maria.html.md create mode 100644 website/source/api/secret/databases/postgresql.html.md diff --git a/website/source/api/secret/databases/cassandra.html.md b/website/source/api/secret/databases/cassandra.html.md new file mode 100644 index 000000000..5e2b5a836 --- /dev/null +++ b/website/source/api/secret/databases/cassandra.html.md @@ -0,0 +1,96 @@ +--- +layout: "api" +page_title: "Cassandra Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-cassandra-maria" +description: |- + The Cassandra plugin for Vault's Database backend generates database credentials to access Cassandra servers. +--- + +# Cassandra Database Plugin HTTP API + +The Cassandra Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the Cassandra database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `hosts` `(string: )` – Specifies a set of comma-delineated Cassandra + hosts to connect to. + +- `username` `(string: )` – Specifies the username to use for + superuser access. + +- `password` `(string: )` – Specifies the password corresponding to + the given username. + +- `tls` `(bool: true)` – Specifies whether to use TLS when connecting to + Cassandra. + +- `insecure_tls` `(bool: false)` – Specifies whether to skip verification of the + server certificate when using TLS. + +- `pem_bundle` `(string: "")` – Specifies concatenated PEM blocks containing a + certificate and private key; a certificate, private key, and issuing CA + certificate; or just a CA certificate. + +- `pem_json` `(string: "")` – Specifies JSON containing a certificate and + private key; a certificate, private key, and issuing CA certificate; or just a + CA certificate. For convenience format is the same as the output of the + `issue` command from the `pki` backend; see + [the pki documentation](/docs/secrets/pki/index.html). + +- `protocol_version` `(int: 2)` – Specifies the CQL protocol version to use. + +- `connect_timeout` `(string: "5s")` – Specifies the connection timeout to use. + +TLS works as follows: + +- If `tls` is set to true, the connection will use TLS; this happens + automatically if `pem_bundle`, `pem_json`, or `insecure_tls` is set + +- If `insecure_tls` is set to true, the connection will not perform verification + of the server certificate; this also sets `tls` to true + +- If only `issuing_ca` is set in `pem_json`, or the only certificate in + `pem_bundle` is a CA certificate, the given CA certificate will be used for + server certificate verification; otherwise the system CA certificates will be + used + +- If `certificate` and `private_key` are set in `pem_bundle` or `pem_json`, + client auth will be turned on for the connection + +`pem_bundle` should be a PEM-concatenated bundle of a private key + client +certificate, an issuing CA certificate, or both. `pem_json` should contain the +same information; for convenience, the JSON format is the same as that output by +the issue command from the PKI backend. + +### Sample Payload + +```json +{ + "plugin_name": "cassandra-database-plugin", + "allowed_roles": "readonly", + "hosts": "cassandra1.local", + "username": "user", + "password": "pass" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/cassandra/config/connection +``` diff --git a/website/source/api/secret/databases/index.html.md b/website/source/api/secret/databases/index.html.md new file mode 100644 index 000000000..9e6015648 --- /dev/null +++ b/website/source/api/secret/databases/index.html.md @@ -0,0 +1,342 @@ +--- +layout: "api" +page_title: "Databases - HTTP API" +sidebar_current: "docs-http-secret-databases" +description: |- + Top page for database secret backend information +--- + +# Database Secret Backend HTTP API + +This is the API documentation for the Vault Database secret backend. For +general information about the usage and operation of the Database backend, +please see the +[Vault Database backend documentation](/docs/secrets/database/index.html). + +This documentation assumes the Database backend is mounted at the +`/database` path in Vault. Since it is possible to mount secret backends at +any location, please update your API calls accordingly. + +## Configure Connection + +This endpoint configures the connection string used to communicate with the +desired database. In addition to the parameters listed here, each Database +plugin has additional, database plugin specifig, parameters for this endpoint. +Please read the HTTP API for the plugin you'd wish to configure to see the full +list of additional parameters. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `name` `(string: )` – Specifies the name for this database + connection. This is specified as part of the URL. + +- `plugin_name` `(string: )` - Specifies the name of the plugin to use + for this connection. + +- `verify_connection` `(bool: true)` – Specifies if the connection is verified + during initial configuration. Defaults to true. + +- `allowed_roles` `(slice: [])` - Array or comma separated string of the roles + allowed to use this connection. Defaults to empty (no roles), if contains a + "*" any role can use this connection. + +### Sample Payload + +```json +{ + "plugin_name": "mysql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "root:mysql@tcp(127.0.0.1:3306)/" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/mysql +``` + +## Read Connection + +This endpoint returns the configuration settings for a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/database/config/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the connection to read. + This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request GET \ + https://vault.rocks/v1/database/config/mysql +``` + +### Sample Response + +```json +{ + "data": { + "allowed_roles": [ + "readonly" + ], + "connection_details": { + "connection_url": "root:mysql@tcp(127.0.0.1:3306)/", + }, + "plugin_name": "mysql-database-plugin" + }, +} +``` + +## Delete Connection + +This endpoint deletes a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `DELETE` | `/database/config/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the connection to delete. + This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request DELETE \ + https://vault.rocks/v1/database/config/mysql +``` + +## Reset Connection + +This endpoint closes a connection and it's underlying plugin and restarts it +with the configuration stored in the barrier. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/reset/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the connection to delete. + This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + https://vault.rocks/v1/database/reset/mysql +``` + +## Create Role + +This endpoint creates or updates a role definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/roles/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to create. This + is specified as part of the URL. + +- `db_name` `(string: )` - The name of the database connection to use + for this role. + +- `default_ttl` `(string: )` - Specifies the TTL for the lease + associated with this role. + +- `max_ttl` `(string: )` - Specifies the maximum TTL for the lease + associated with this role. + +- `creation_statements` `(string: )` – Specifies the database + statements executed to create and configure a user. Must be a + semicolon-separated string, a base64-encoded semicolon-separated string, a + serialized JSON string array, or a base64-encoded serialized JSON string + array. The '{{name}}', '{{password}}' and '{{expiration}}' values will be + substituted. + +- `revocation_statements` `(string: "")` – Specifies the database statements to + be executed to revoke a user. Must be a semicolon-separated string, a + base64-encoded semicolon-separated string, a serialized JSON string array, or + a base64-encoded serialized JSON string array. The '{{name}}' value will be + substituted. + +- `rollback_statements` `(string: "")` – Specifies the database statements to be + executed rollback a create operation in the event of an error. Not every + plugin type will support this functionality. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted. + +- `renew_statements` `(string: "")` – Specifies the database statements to be + executed to renew a user. Not every plugin type will support this + functionality. Must be a semicolon-separated string, a base64-encoded + semicolon-separated string, a serialized JSON string array, or a + base64-encoded serialized JSON string array. The '{{name}}' and + '{{expiration}}` values will be substituted. + + +### Sample Payload + +```json +{ + "db_name": "mysql", + "creation_statements": "CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';GRANT SELECT ON *.* TO '{{name}}'@'%';", + "default_ttl": "1h", + "max_ttl": "24h" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/roles/my-role +``` + +## Read Role + +This endpoint queries the role definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/database/roles/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to read. This + is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + https://vault.rocks/v1/database/roles/my-role +``` + +### Sample Response + +```json +{ + "data": { + "creation_statements": "CREATE ROLE \"{{name}}\" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT SELECT ON ALL TABLES IN SCHEMA public TO \"{{name}}\";", + "db_name": "mysql", + "default_ttl": 3600, + "max_ttl": 86400, + "renew_statements": "", + "revocation_statements": "", + "rollback_statements": "" + }, +} +``` + +## List Roles + +This endpoint returns a list of available roles. Only the role names are +returned, not any values. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `LIST` | `/database/roles` | `200 application/json` | + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request LIST \ + https://vault.rocks/v1/database/roles +``` + +### Sample Response + +```json +{ + "auth": null, + "data": { + "keys": ["dev", "prod"] + }, + "lease_duration": 2764800, + "lease_id": "", + "renewable": false +} +``` + +## Delete Role + +This endpoint deletes the role definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `DELETE` | `/database/roles/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to delete. This + is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request DELETE \ + https://vault.rocks/v1/database/roles/my-role +``` + +## Generate Credentials + +This endpoint generates a new set of dynamic credentials based on the named +role. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/database/creds/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to create + credentials against. This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + https://vault.rocks/v1/database/creds/my-role +``` + +### Sample Response + +```json +{ + "data": { + "username": "root-1430158508-126", + "password": "132ae3ef-5a64-7499-351e-bfe59f3a2a21" + } +} +``` diff --git a/website/source/api/secret/databases/mssql.html.md b/website/source/api/secret/databases/mssql.html.md new file mode 100644 index 000000000..09893df45 --- /dev/null +++ b/website/source/api/secret/databases/mssql.html.md @@ -0,0 +1,60 @@ +--- +layout: "api" +page_title: "MSSQL Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-mssql-maria" +description: |- + The MSSQL plugin for Vault's Database backend generates database credentials to access MSSQL servers. +--- + +# MSSQL Database Plugin HTTP API + +The MSSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MSSQL database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `connection_url` `(string: )` - Specifies the MSSQL DSN. + +- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use + for this connection. + +- `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle + connections to the database. A zero uses the value of `max_open_connections` + and a negative value disables idle connections. If larger than + `max_open_connections` it will be reduced to be equal. + +- `max_connection_lifetime` `(string: "0s")` - Specifies the maximum amount of + time a connection may be reused. If <= 0s connections are reused forever. + +### Sample Payload + +```json +{ + "plugin_name": "mssql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "sqlserver://sa:yourStrong(!)Password@localhost:1433", + "max_open_connections": 5, + "max_connection_lifetime": "5s", +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/mssql +``` + diff --git a/website/source/api/secret/databases/mysql-maria.html.md b/website/source/api/secret/databases/mysql-maria.html.md new file mode 100644 index 000000000..981506798 --- /dev/null +++ b/website/source/api/secret/databases/mysql-maria.html.md @@ -0,0 +1,60 @@ +--- +layout: "api" +page_title: "MySQL/MariaDB Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-mysql-maria" +description: |- + The MySQL/MariaDB plugin for Vault's Database backend generates database credentials to access MySQL and MariaDB servers. +--- + +# MySQL/MariaDB Database Plugin HTTP API + +The MySQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MySQL database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `connection_url` `(string: )` - Specifies the MySQL DSN. + +- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use + for this connection. + +- `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle + connections to the database. A zero uses the value of `max_open_connections` + and a negative value disables idle connections. If larger than + `max_open_connections` it will be reduced to be equal. + +- `max_connection_lifetime` `(string: "0s")` - Specifies the maximum amount of + time a connection may be reused. If <= 0s connections are reused forever. + +### Sample Payload + +```json +{ + "plugin_name": "mysql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "root:mysql@tcp(127.0.0.1:3306)/" + "max_open_connections": 5, + "max_connection_lifetime": "5s", +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/mysql +``` + diff --git a/website/source/api/secret/databases/postgresql.html.md b/website/source/api/secret/databases/postgresql.html.md new file mode 100644 index 000000000..5ff6b8022 --- /dev/null +++ b/website/source/api/secret/databases/postgresql.html.md @@ -0,0 +1,60 @@ +--- +layout: "api" +page_title: "PostgreSQL Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-postgresql-maria" +description: |- + The PostgreSQL plugin for Vault's Database backend generates database credentials to access PostgreSQL servers. +--- + +# PostgreSQL Database Plugin HTTP API + +The PostgreSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the PostgreSQL database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `connection_url` `(string: )` - Specifies the PostgreSQL DSN. + +- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use + for this connection. + +- `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle + connections to the database. A zero uses the value of `max_open_connections` + and a negative value disables idle connections. If larger than + `max_open_connections` it will be reduced to be equal. + +- `max_connection_lifetime` `(string: "0s")` - Specifies the maximum amount of + time a connection may be reused. If <= 0s connections are reused forever. + +### Sample Payload + +```json +{ + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "postgresql://root:root@localhost:5432/postgres", + "max_open_connections": 5, + "max_connection_lifetime": "5s", +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/postgresql +``` + diff --git a/website/source/docs/secrets/databases/custom.html.md b/website/source/docs/secrets/databases/custom.html.md index 5911e1d80..7d21a19c9 100644 --- a/website/source/docs/secrets/databases/custom.html.md +++ b/website/source/docs/secrets/databases/custom.html.md @@ -11,7 +11,8 @@ description: |- The Database backend allows new functionality to be added through a plugin interface without needing to modify vault's core code. This allows you write your own code to generate credentials in any database you wish. It also allows -databases that require dynamically linked libraries to be used with vault. +databases that require dynamically linked libraries to be used as plugins while +keeping Vault itself statically linked. ~> **Advanced topic!** Plugin development is a highly advanced topic in Vault, and is not required knowledge for day-to-day usage. @@ -81,7 +82,7 @@ Replacing `MyPlugin` with the actual implementation of your plugin. The second parameter to `Serve` takes in an optional vault `api.TLSConfig` for configuring the plugin to communicate with vault for the initial unwrap call. -This if useful if your vault setup requires client certificate checks. This +This is useful if your vault setup requires client certificate checks. This config wont be used once the plugin unwraps its own TLS cert and key. ## Running your plugin diff --git a/website/source/layouts/api.erb b/website/source/layouts/api.erb index c209937bc..ea8e35624 100644 --- a/website/source/layouts/api.erb +++ b/website/source/layouts/api.erb @@ -21,7 +21,7 @@ AWS > - Cassandra + Cassandra (Deprecated) > Consul @@ -29,6 +29,25 @@ > Cubbyhole + + > + Databases (Beta) + + + > Generic @@ -36,16 +55,16 @@ MongoDB > - MSSQL + MSSQL (Deprecated) > - MySQL + MySQL (Deprecated) > PKI > - PostgreSQL + PostgreSQL (Deprecated) > RabbitMQ From e92818e0aeb15b1630a2c7a315f6da727d91cef0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 10:25:12 -0700 Subject: [PATCH 136/162] Upate links in docs --- website/source/api/secret/databases/index.html.md | 2 +- website/source/docs/secrets/databases/cassandra.html.md | 6 +++--- website/source/docs/secrets/databases/index.html.md | 2 +- website/source/docs/secrets/databases/mssql.html.md | 6 +++--- website/source/docs/secrets/databases/mysql-maria.html.md | 6 +++--- website/source/docs/secrets/databases/postgresql.html.md | 6 +++--- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/website/source/api/secret/databases/index.html.md b/website/source/api/secret/databases/index.html.md index 9e6015648..f55998ace 100644 --- a/website/source/api/secret/databases/index.html.md +++ b/website/source/api/secret/databases/index.html.md @@ -11,7 +11,7 @@ description: |- This is the API documentation for the Vault Database secret backend. For general information about the usage and operation of the Database backend, please see the -[Vault Database backend documentation](/docs/secrets/database/index.html). +[Vault Database backend documentation](/docs/secrets/databases/index.html). This documentation assumes the Database backend is mounted at the `/database` path in Vault. Since it is possible to mount secret backends at diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index b3d87f7ed..1d8468ad3 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -14,7 +14,7 @@ The Cassandra Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the Cassandra database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -55,8 +55,8 @@ This role can be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [Cassandra database -plugin API](/api/secret/database/cassandra.html) page. +plugin API](/api/secret/databases/cassandra.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/index.html.md b/website/source/docs/secrets/databases/index.html.md index cf366d9c9..c88699c44 100644 --- a/website/source/docs/secrets/databases/index.html.md +++ b/website/source/docs/secrets/databases/index.html.md @@ -93,5 +93,5 @@ username v-root-e2978cd0- ## API The Database secret backend has a full HTTP API. Please see the [Database secret -backend API](/api/secret/database/index.html) for more details. +backend API](/api/secret/databases/index.html) for more details. diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 0eefe1764..fec8924b3 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -14,7 +14,7 @@ The MSSQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the MSSQL database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -53,8 +53,8 @@ This role can now be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [MSSQL database -plugin API](/api/secret/database/mssql.html) page. +plugin API](/api/secret/databases/mssql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index 76ca193fc..c5eea4b7b 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -14,7 +14,7 @@ The MySQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the MySQL database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -51,8 +51,8 @@ This role can now be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [MySQL database -plugin API](/api/secret/database/mysql.html) page. +plugin API](/api/secret/databases/mysql-maria.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index 81716132f..72601e34f 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -14,7 +14,7 @@ The PostgreSQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the PostgreSQL database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -53,8 +53,8 @@ This role can be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [PostgreSQL database -plugin API](/api/secret/database/postgresql.html) page. +plugin API](/api/secret/databases/postgresql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. From bf29861d4971666e7210f2107e0590bdfcdb96cf Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 11:43:24 -0700 Subject: [PATCH 137/162] Add the plugins catalog API docs --- .../source/api/system/plugins-catalog.html.md | 155 ++++++++++++++++++ website/source/layouts/api.erb | 3 + 2 files changed, 158 insertions(+) create mode 100644 website/source/api/system/plugins-catalog.html.md diff --git a/website/source/api/system/plugins-catalog.html.md b/website/source/api/system/plugins-catalog.html.md new file mode 100644 index 000000000..b95526194 --- /dev/null +++ b/website/source/api/system/plugins-catalog.html.md @@ -0,0 +1,155 @@ +--- +layout: "api" +page_title: "/sys/plugins/catalog - HTTP API" +sidebar_current: "docs-http-system-plugins-catalog" +description: |- + The `/sys/plugins/catalog` endpoint is used to manage plugins. +--- + +# `/sys/plugins/catalog` + +The `/sys/plugins/catalog` endpoint is used to list, register, update, and +remove plugins in Vault's catalog. Plugins must be registered before use, and +once registered backends can use the plugin by querying the catalog. + +## List Plugins + +This endpoint lists the plugins in the catalog. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `LIST` | `/sys/plugins/catalog/` | `200 application/json` | + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request LIST + https://vault.rocks/v1/sys/plugins/catalog +``` + +### Sample Response + +```javascript +{ + "data": { + "keys": [ + "cassandra-database-plugin", + "mssql-database-plugin", + "mysql-database-plugin", + "postgresql-database-plugin" + ] + } +} +``` + +## Register Plugin + +This endpoint registers a new plugin, or updates an existing one with the +supplied name. + +- **`sudo` required** – This endpoint requires `sudo` capability in addition to + any path-specific capabilities. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `PUT` | `/sys/plugins/catalog/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name for this plugin. The name + is what is used to look up plugins in the catalog. This is part of the request + URL. + +- `sha_256` `(string: )` – This is the SHA256 sum of the plugin's + binary. Before a plugin is run it's SHA will be checked against this value, if + they do not match the plugin can not be run. + +- `command` `(string: )` – Specifies the command used to execute the + plugin. This is relative to the plugin directory. e.g. `"myplugin + --my_flag=1"` + +### Sample Payload + +```json +{ + "sha_256": "d130b9a0fbfddef9709d8ff92e5e6053ccd246b78632fc03b8548457026961e9", + "command": "mysql-database-plugin" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request PUT \ + --data @payload.json \ + https://vault.rocks/v1/sys/plugins/catalog/example-plugin +``` + +## Read Plugin + +This endpoint returns the configuration data for the plugin with the given name. + +- **`sudo` required** – This endpoint requires `sudo` capability in addition to + any path-specific capabilities. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/sys/plugins/catalog/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the plugin to retrieve. + This is part of the request URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request GET \ + https://vault.rocks/v1/sys/plugins/catalog/example-plugin +``` + +### Sample Response + +```javascript +{ + "data": { + "plugin": { + "args": [], + "builtin": false, + "command": "/tmp/vault-plugins/mysql-database-plugin", + "name": "example-plugin", + "sha256": "0TC5oPv93vlwnY/5Ll5gU8zSRreGMvwDuFSEVwJpYek=" + } + } +} +``` +## Remove Plugin from Catalog + +This endpoint removes the plugin with the given name. + +- **`sudo` required** – This endpoint requires `sudo` capability in addition to + any path-specific capabilities. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `DELETE` | `/sys/plugins/catalog/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the plugin to delete. + This is part of the request URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request DELETE \ + https://vault.rocks/v1/sys/plugins/catalog/example-plugin +``` diff --git a/website/source/layouts/api.erb b/website/source/layouts/api.erb index ea8e35624..c6e92d026 100644 --- a/website/source/layouts/api.erb +++ b/website/source/layouts/api.erb @@ -120,6 +120,9 @@ > /sys/mounts + > + /sys/plugins/catalog + > /sys/policy From 37bd3ed76e3a225f0cfb91a5bd047b2402c6ca55 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:01:05 -0700 Subject: [PATCH 138/162] Use log to output errors instead of fmt --- plugins/database/cassandra/cassandra-database-plugin/main.go | 4 ++-- plugins/database/mssql/mssql-database-plugin/main.go | 4 ++-- plugins/database/mysql/mysql-database-plugin/main.go | 4 ++-- .../database/postgresql/postgresql-database-plugin/main.go | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/plugins/database/cassandra/cassandra-database-plugin/main.go b/plugins/database/cassandra/cassandra-database-plugin/main.go index bb3f44142..c70997897 100644 --- a/plugins/database/cassandra/cassandra-database-plugin/main.go +++ b/plugins/database/cassandra/cassandra-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := cassandra.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go index d52fd13db..5f05c5dff 100644 --- a/plugins/database/mssql/mssql-database-plugin/main.go +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := mssql.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go index a9389f504..249e5afee 100644 --- a/plugins/database/mysql/mysql-database-plugin/main.go +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := mysql.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go index e6acb0584..ac3cf95a7 100644 --- a/plugins/database/postgresql/postgresql-database-plugin/main.go +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := postgresql.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } From cf15c023dfc5723fe2c0e71d29e0513196e5eb24 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:11:30 -0700 Subject: [PATCH 139/162] Use ParseDurationSecond to parse the timeouts in connutil --- plugins/helper/database/connutil/cassandra.go | 40 +++++++++++-------- plugins/helper/database/connutil/sql.go | 15 +++---- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 1babc3cbd..27fb25195 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -11,28 +11,30 @@ import ( "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/tlsutil" ) // CassandraConnectionProducer implements ConnectionProducer and provides an // interface for cassandra databases to make connections. type CassandraConnectionProducer struct { - Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` - Username string `json:"username" structs:"username" mapstructure:"username"` - Password string `json:"password" structs:"password" mapstructure:"password"` - TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` - InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` - ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` - ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` - TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - Initialized bool - Type string - session *gocql.Session + connectTimeout time.Duration + Initialized bool + Type string + session *gocql.Session sync.Mutex } @@ -46,6 +48,11 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve } c.Initialized = true + c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) + if err != nil { + return fmt.Errorf("invalid connect_timeout: %s", err) + } + if verifyConnection { if _, err := c.Connection(); err != nil { return fmt.Errorf("error Initalizing Connection: %s", err) @@ -101,8 +108,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { clusterConfig.ProtoVersion = 2 } - clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - + clusterConfig.Timeout = c.connectTimeout if c.TLS { var tlsConfig *tls.Config if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index 0bfc5f9f6..4a6368560 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -10,19 +10,20 @@ import ( // Import sql drivers _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" + "github.com/hashicorp/vault/helper/parseutil" _ "github.com/lib/pq" "github.com/mitchellh/mapstructure" ) // SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases type SQLConnectionProducer struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` Type string - MaxConnectionLifetime time.Duration + maxConnectionLifetime time.Duration Initialized bool db *sql.DB sync.Mutex @@ -51,7 +52,7 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo c.MaxConnectionLifetimeRaw = "0s" } - c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw) + c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) if err != nil { return fmt.Errorf("invalid max_connection_lifetime: %s", err) } @@ -110,7 +111,7 @@ func (c *SQLConnectionProducer) Connection() (interface{}, error) { // since the request rate shouldn't be high. c.db.SetMaxOpenConns(c.MaxOpenConnections) c.db.SetMaxIdleConns(c.MaxIdleConnections) - c.db.SetConnMaxLifetime(c.MaxConnectionLifetime) + c.db.SetConnMaxLifetime(c.maxConnectionLifetime) return c.db, nil } From 223598c67522e36f137c5fcea816ff57f6cb4085 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:33:56 -0700 Subject: [PATCH 140/162] Add the other mysql plugin types with the correct username length settings --- helper/builtinplugins/builtin.go | 8 ++++- plugins/database/mysql/mysql.go | 50 ++++++++++++++++++++------------ 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 3dec8588b..8e6ed22ef 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -10,7 +10,13 @@ import ( type BuiltinFactory func() (interface{}, error) var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ - "mysql-database-plugin": mysql.New, + // These four plugins all use the same mysql implementation but with + // different username settings passed by the constructor. + "mysql-database-plugin": mysql.New(mysql.DisplayNameLen, mysql.UsernameLen), + "aurora-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "rds-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "mysql-legacy-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "postgresql-database-plugin": postgresql.New, "mssql-database-plugin": mssql.New, "cassandra-database-plugin": cassandra.New, diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 7a44d7341..b875af520 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -14,11 +14,20 @@ import ( "github.com/hashicorp/vault/plugins/helper/database/dbutil" ) -const defaultMysqlRevocationStmts = ` - REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; - DROP USER '{{name}}'@'%' -` -const mySQLTypeName = "mysql" +const ( + defaultMysqlRevocationStmts = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' + ` + mySQLTypeName = "mysql" +) + +var ( + DisplayNameLen int = 10 + LegacyDisplayNameLen int = 4 + UsernameLen int = 32 + LegacyUsernameLen int = 16 +) type MySQL struct { connutil.ConnectionProducer @@ -26,26 +35,29 @@ type MySQL struct { } // New implements builtinplugins.BuiltinFactory -func New() (interface{}, error) { - connProducer := &connutil.SQLConnectionProducer{} - connProducer.Type = mySQLTypeName +func New(displayLen, usernameLen int) func() (interface{}, error) { + return func() (interface{}, error) { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = mySQLTypeName - credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 4, - UsernameLen: 16, + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: displayLen, + UsernameLen: usernameLen, + } + + dbType := &MySQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType, nil } - - dbType := &MySQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - return dbType, nil } // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run(apiTLSConfig *api.TLSConfig) error { - dbType, err := New() + f := New(DisplayNameLen, UsernameLen) + dbType, err := f() if err != nil { return err } From 015e63164b5c7f397480eba5caa451bdae42e6eb Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:36:16 -0700 Subject: [PATCH 141/162] Fix mysql plugin tests --- plugins/database/mysql/mysql_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index c86f9c2f6..72dbd8156 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -66,7 +66,8 @@ func TestMySQL_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() + f := New(DisplayNameLen, UsernameLen) + dbRaw, _ := f() db := dbRaw.(*MySQL) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) @@ -93,7 +94,8 @@ func TestMySQL_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() + f := New(DisplayNameLen, UsernameLen) + dbRaw, _ := f() db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) @@ -129,7 +131,8 @@ func TestMySQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() + f := New(DisplayNameLen, UsernameLen) + dbRaw, _ := f() db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) From b1a5f45d2c60307961fb98870bbabfae2e430aa9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:45:27 -0700 Subject: [PATCH 142/162] Fix parsing the connection duration when it's nil --- plugins/helper/database/connutil/cassandra.go | 3 +++ plugins/helper/database/connutil/sql.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 27fb25195..958bef201 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -48,6 +48,9 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve } c.Initialized = true + if c.ConnectTimeoutRaw == nil { + c.ConnectTimeoutRaw = "0s" + } c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) if err != nil { return fmt.Errorf("invalid connect_timeout: %s", err) diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index 4a6368560..5067e10d7 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -48,7 +48,7 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo if c.MaxIdleConnections > c.MaxOpenConnections { c.MaxIdleConnections = c.MaxOpenConnections } - if c.MaxConnectionLifetimeRaw == "" { + if c.MaxConnectionLifetimeRaw == nil { c.MaxConnectionLifetimeRaw = "0s" } From 9faf234869732d37b8363d23492f745cdf7dd1bd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 15:36:49 -0700 Subject: [PATCH 143/162] Fix the TLS functionality in cassandra plugin --- plugins/helper/database/connutil/cassandra.go | 70 +++++++++++++++---- plugins/helper/database/connutil/sql.go | 4 ++ 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 958bef201..869c39e3b 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -23,18 +23,21 @@ type CassandraConnectionProducer struct { Password string `json:"password" structs:"password" mapstructure:"password"` TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + PemBundle string `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"` + PemJSON string `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"` connectTimeout time.Duration - Initialized bool - Type string - session *gocql.Session + certificate string + privateKey string + issuingCA string + + Initialized bool + Type string + session *gocql.Session sync.Mutex } @@ -56,6 +59,47 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve return fmt.Errorf("invalid connect_timeout: %s", err) } + switch { + case len(c.Hosts) == 0: + return fmt.Errorf("hosts cannot be empty") + case len(c.Username) == 0: + return fmt.Errorf("username cannot be empty") + case len(c.Password) == 0: + return fmt.Errorf("password cannot be empty") + } + + var certBundle *certutil.CertBundle + var parsedCertBundle *certutil.ParsedCertBundle + switch { + case len(c.PemJSON) != 0: + parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON)) + if err != nil { + return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err) + } + certBundle, err = parsedCertBundle.ToCertBundle() + if err != nil { + return fmt.Errorf("Error marshaling PEM information: %s", err) + } + c.certificate = certBundle.Certificate + c.privateKey = certBundle.PrivateKey + c.issuingCA = certBundle.IssuingCA + c.TLS = true + + case len(c.PemBundle) != 0: + parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle) + if err != nil { + return fmt.Errorf("Error parsing the given PEM information: %s", err) + } + certBundle, err = parsedCertBundle.ToCertBundle() + if err != nil { + return fmt.Errorf("Error marshaling PEM information: %s", err) + } + c.certificate = certBundle.Certificate + c.privateKey = certBundle.PrivateKey + c.issuingCA = certBundle.IssuingCA + c.TLS = true + } + if verifyConnection { if _, err := c.Connection(); err != nil { return fmt.Errorf("error Initalizing Connection: %s", err) @@ -114,18 +158,18 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { clusterConfig.Timeout = c.connectTimeout if c.TLS { var tlsConfig *tls.Config - if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { - if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { + if len(c.certificate) > 0 || len(c.issuingCA) > 0 { + if len(c.certificate) > 0 && len(c.privateKey) == 0 { return nil, fmt.Errorf("found certificate for TLS authentication but no private key") } certBundle := &certutil.CertBundle{} - if len(c.Certificate) > 0 { - certBundle.Certificate = c.Certificate - certBundle.PrivateKey = c.PrivateKey + if len(c.certificate) > 0 { + certBundle.Certificate = c.certificate + certBundle.PrivateKey = c.privateKey } - if len(c.IssuingCA) > 0 { - certBundle.IssuingCA = c.IssuingCA + if len(c.issuingCA) > 0 { + certBundle.IssuingCA = c.issuingCA } parsedCertBundle, err := certBundle.ToParsedCertBundle() diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index 5067e10d7..04269798f 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -38,6 +38,10 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo return err } + if len(c.ConnectionURL) == 0 { + return fmt.Errorf("connection_url cannot be empty") + } + if c.MaxOpenConnections == 0 { c.MaxOpenConnections = 2 } From cbcb8635a474ba807cdabd61d521ad13cccb6450 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 16:34:09 -0700 Subject: [PATCH 144/162] Update databse backend tests to use the APIClientMeta for the plugin conns --- builtin/logical/database/backend_test.go | 23 ++++++++++++++++--- .../logical/database/dbplugin/plugin_test.go | 9 ++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 70ec22ee2..27c20d332 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "io/ioutil" "log" stdhttp "net/http" "os" @@ -10,7 +11,6 @@ import ( "sync" "testing" - "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" @@ -109,11 +109,28 @@ func TestBackend_PluginMain(t *testing.T) { return } - err := postgresql.Run(&api.TLSConfig{Insecure: true}) + content := []byte(vault.TestClusterCACert) + tmpfile, err := ioutil.TempFile("", "example") if err != nil { t.Fatal(err) } - t.Fatal("We shouldn't get here") + + defer os.Remove(tmpfile.Name()) // clean up + + if _, err := tmpfile.Write(content); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + args := []string{"--ca-cert=" + tmpfile.Name()} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + postgresql.Run(apiClientMeta.GetTLSConfig()) } func TestBackend_config_connection(t *testing.T) { diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index c38d85ed3..c95e119e0 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" @@ -107,7 +106,13 @@ func TestPlugin_Main(t *testing.T) { users: make(map[string][]string), } - plugins.Serve(plugin, &api.TLSConfig{Insecure: true}) + args := []string{"--tls-skip-verify=true"} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) } func TestPlugin_Initialize(t *testing.T) { From 0875e78a136e1e4c9091ee85e346213a1d7089a5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 17:37:34 -0700 Subject: [PATCH 145/162] Feedback from PR --- builtin/logical/database/backend.go | 8 ++--- .../database/path_config_connection.go | 4 +-- helper/pluginutil/runner.go | 6 ++-- helper/pluginutil/tls.go | 30 +++++++++---------- 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index da8c8384a..3d1502805 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -54,7 +54,7 @@ type databaseBackend struct { sync.RWMutex } -// resetAllDBs closes all connections from all database types +// closeAllDBs closes all connections from all database types func (b *databaseBackend) closeAllDBs() { b.Lock() defer b.Unlock() @@ -120,8 +120,8 @@ func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*Datab return &config, nil } -func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { - entry, err := s.Get("role/" + n) +func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) { + entry, err := s.Get("role/" + roleName) if err != nil { return nil, err } @@ -170,7 +170,7 @@ func (b *databaseBackend) closeIfShutdown(name string, err error) { const backendHelp = ` The database backend supports using many different databases as secret backends, including but not limited to: -cassandra, msslq, mysql, postgres +cassandra, mssql, mysql, postgres After mounting this backend, configure it using the endpoints within the "database/config/" path. diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f37674285..e84212bb8 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -115,7 +115,7 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { } } -// pathConnectionRead reads out the connection configuration +// connectionReadHandler reads out the connection configuration func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) @@ -248,7 +248,7 @@ database. This path runs the provided plugin name and passes the configured connection details to the plugin. See the documentation for the plugin specified for a full list of accepted connection details. -In addition to the database specific connection details, this endpoing also +In addition to the database specific connection details, this endpoint also accepts: * "plugin_name" (required) - The name of a builtin or previously registered diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 9dbe5c51b..4b25ba16b 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -47,20 +47,20 @@ type PluginRunner struct { // plugin. func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate - certBytes, key, err := GenerateCert() + certBytes, key, err := generateCert() if err != nil { return nil, err } // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := CreateClientTLSConfig(certBytes, key) + clientTLSConfig, err := createClientTLSConfig(certBytes, key) if err != nil { return nil, err } // Use CA to sign a server cert and wrap the values in a response wrapped // token. - wrapToken, err := WrapServerConfig(wrapper, certBytes, key) + wrapToken, err := wrapServerConfig(wrapper, certBytes, key) if err != nil { return nil, err } diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index b355079d6..1a7fbe783 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -10,17 +10,15 @@ import ( "encoding/base64" "errors" "fmt" - "math/big" - mathrand "math/rand" "net/url" "os" - "strings" "time" "github.com/SermoDigital/jose/jws" "github.com/hashicorp/errwrap" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/certutil" ) var ( @@ -29,9 +27,9 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) -// generateSignedCert is used internally to create certificates for the plugin -// client and server. These certs are signed by the given CA Cert and Key. -func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { +// generateCert is used internally to create certificates for the plugin +// client and server. +func generateCert() ([]byte, *ecdsa.PrivateKey, error) { key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { return nil, nil, err @@ -42,6 +40,11 @@ func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { return nil, nil, err } + sn, err := certutil.GenerateSerialNumber() + if err != nil { + return nil, nil, err + } + template := &x509.Certificate{ Subject: pkix.Name{ CommonName: host, @@ -52,7 +55,7 @@ func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { x509.ExtKeyUsageServerAuth, }, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), + SerialNumber: sn, NotBefore: time.Now().Add(-30 * time.Second), NotAfter: time.Now().Add(262980 * time.Hour), IsCA: true, @@ -66,9 +69,9 @@ func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { return certBytes, key, nil } -// CreateClientTLSConfig creates a signed certificate and returns a configured +// createClientTLSConfig creates a signed certificate and returns a configured // TLS config. -func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config, error) { +func createClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config, error) { clientCert, err := x509.ParseCertificate(certBytes) if err != nil { return nil, fmt.Errorf("error parsing generated plugin certificate: %v", err) @@ -95,9 +98,9 @@ func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config return tlsConfig, nil } -// WrapServerConfig is used to create a server certificate and private key, then +// wrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { +func wrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { rawKey, err := x509.MarshalECPrivateKey(key) if err != nil { return "", err @@ -120,11 +123,6 @@ func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, er return func() (*tls.Config, error) { unwrapToken := os.Getenv(PluginUnwrapTokenEnv) - // Ensure unwrap token is a JWT - if strings.Count(unwrapToken, ".") != 2 { - return nil, errors.New("Could not parse unwraptoken") - } - // Parse the JWT and retrieve the vault address wt, err := jws.ParseJWT([]byte(unwrapToken)) if err != nil { From ce391ca425009f5180dbddd7e1640578bd347988 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 18:41:39 -0700 Subject: [PATCH 146/162] add new mysql plugin names and fix grammar --- .../docs/secrets/databases/cassandra.html.md | 2 +- .../source/docs/secrets/databases/mssql.html.md | 2 +- .../docs/secrets/databases/mysql-maria.html.md | 15 +++++++++++++-- .../docs/secrets/databases/postgresql.html.md | 2 +- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index 1d8468ad3..0e29d0300 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -57,6 +57,6 @@ This role can be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [Cassandra database plugin API](/api/secret/databases/cassandra.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index fec8924b3..c2f7ff5fe 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -55,6 +55,6 @@ This role can now be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [MSSQL database plugin API](/api/secret/databases/mssql.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index c5eea4b7b..ae6c19eac 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -8,7 +8,8 @@ description: |- # MySQL/MariaDB Database Plugin -Name: `mysql-database-plugin` +Name: `mysql-database-plugin`, `aurora-database-plugin`, `rds-database-plugin`, +`mysql-legacy-database-plugin` The MySQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on @@ -17,6 +18,16 @@ configured roles for the MySQL database. See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. +This plugin has a few different instances built into vault, each instance is for +a slightly different MySQL driver. The only difference between these plugins is +the length of usernames generated by the plugin as different versions of mysql +accept different lengths. The availible plugins are: + + - mysql-database-plugin + - aurora-database-plugin + - rds-database-plugin + - mysql-legacy-database-plugin + ## Quick Start After the Database Backend is mounted you can configure a MySQL connection @@ -53,6 +64,6 @@ This role can now be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [MySQL database plugin API](/api/secret/databases/mysql-maria.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index 72601e34f..e04cc087c 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -55,6 +55,6 @@ This role can be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [PostgreSQL database plugin API](/api/secret/databases/postgresql.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. From 29bfc0a0d4eacc1e9cf6c08f863a73010c702f32 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 10:41:59 -0700 Subject: [PATCH 147/162] PR comments --- builtin/logical/database/path_roles.go | 2 +- helper/builtinplugins/builtin.go | 4 ++-- logical/system_view.go | 2 +- website/source/docs/secrets/databases/mysql-maria.html.md | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index c81261804..8be33c0a1 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -57,7 +57,7 @@ func pathRoles(b *databaseBackend) *framework.Path { }, "rollback_statements": { Type: framework.TypeString, - Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + Description: `Statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string array, or a base64-encoded serialized JSON string array. The '{{name}}' value will be substituted.`, diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 8e6ed22ef..a2100e931 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -13,8 +13,8 @@ var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ // These four plugins all use the same mysql implementation but with // different username settings passed by the constructor. "mysql-database-plugin": mysql.New(mysql.DisplayNameLen, mysql.UsernameLen), - "aurora-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), - "rds-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "mysql-aurora-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "mysql-rds-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), "mysql-legacy-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), "postgresql-database-plugin": postgresql.New, diff --git a/logical/system_view.go b/logical/system_view.go index 175edc0f9..64fc51c7b 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -49,7 +49,7 @@ type SystemView interface { // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(string) (*pluginutil.PluginRunner, error) - // MlockEnabled returns the configuration setting for Enableing mlock on + // MlockEnabled returns the configuration setting for enabling mlock on // plugins. MlockEnabled() bool } diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index ae6c19eac..f4cf3640b 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -8,7 +8,7 @@ description: |- # MySQL/MariaDB Database Plugin -Name: `mysql-database-plugin`, `aurora-database-plugin`, `rds-database-plugin`, +Name: `mysql-database-plugin`, `mysql-aurora-database-plugin`, `mysql-rds-database-plugin`, `mysql-legacy-database-plugin` The MySQL Database Plugin is one of the supported plugins for the Database @@ -24,8 +24,8 @@ the length of usernames generated by the plugin as different versions of mysql accept different lengths. The availible plugins are: - mysql-database-plugin - - aurora-database-plugin - - rds-database-plugin + - mysql-aurora-database-plugin + - mysql-rds-database-plugin - mysql-legacy-database-plugin ## Quick Start From 4c0e3c5d2fe2a418e1e3de9356d55a70fc3927ac Mon Sep 17 00:00:00 2001 From: mymercurialsky Date: Thu, 4 May 2017 10:49:42 -0700 Subject: [PATCH 148/162] Implemented TOTP Secret Backend (#2492) * Initialized basic outline of TOTP backend using Postgresql backend as template * Updated TOTP backend.go's structure and help string * Updated TOTP path_roles.go's structure and help strings * Updated TOTP path_role_create.go's structure and help strings * Fixed typo in path_roles.go * Fixed errors in path_role_create.go and path_roles.go * Added TOTP secret backend information to cli commands * Fixed build errors in path_roles.go and path_role_create.go * Changed field values of period and digits from uint to int, added uint conversion of period when generating passwords * Initialized TOTP test file based on structure of postgresql test file * Added enforcement of input values * Added otp library to vendor folder * Added test steps and cleaned up errors * Modified read credential test step, not working yet * Use of vendored package not allowed - Test error * Removed vendor files for TOTP library * Revert "Removed vendor files for TOTP library" This reverts commit fcd030994bc1741dbf490f3995944e091b11da61. * Hopefully fixed vendor folder issue with TOTP Library * Added additional tests for TOTP backend * Cleaned up comments in TOTP backend_test.go * Added default values of period, algorithm and digits to field schema * Changed account_name and issuer fields to optional * Removed MD5 as a hash algorithm option * Implemented requested pull request changes * Added ability to validate TOTP codes * Added ability to have a key generated * Added skew, qr size and key size parameters * Reset vendor.json prior to merge * Readded otp and barcode libraries to vendor.json * Modified help strings for path_role_create.go * Fixed test issue in testAccStepReadRole * Cleaned up error formatting, variable names and path names. Also added some additional documentation * Moveed barcode and url output to key creation function and did some additional cleanup based on requested changes * Added ability to pass in TOTP urls * Added additional tests for TOTP server functions * Removed unused QRSize, URL and Generate members of keyEntry struct * Removed unnecessary urlstring variable from pathKeyCreate * Added website documentation for TOTP secret backend * Added errors if generate is true and url or key is passed, removed logger from backend, and revised parameter documentation. * Updated website documentation and added QR example * Added exported variable and ability to disable QR generation, cleaned up error reporting, changed default skew value, updated documentation and added additional tests * Updated API documentation to inlude to exported variable and qr size option * Cleaned up return statements in path_code, added error handling while validating codes and clarified documentation for generate parameters in path_keys --- builtin/logical/totp/backend.go | 37 + builtin/logical/totp/backend_test.go | 1128 +++++++++++++++++ builtin/logical/totp/path_code.go | 110 ++ builtin/logical/totp/path_keys.go | 424 +++++++ cli/commands.go | 2 + vendor/github.com/boombuler/barcode/LICENSE | 21 + vendor/github.com/boombuler/barcode/README.md | 18 + .../github.com/boombuler/barcode/barcode.go | 27 + .../boombuler/barcode/qr/alphanumeric.go | 66 + .../boombuler/barcode/qr/automatic.go | 23 + .../github.com/boombuler/barcode/qr/blocks.go | 59 + .../boombuler/barcode/qr/encoder.go | 416 ++++++ .../boombuler/barcode/qr/errorcorrection.go | 29 + .../boombuler/barcode/qr/numeric.go | 56 + .../github.com/boombuler/barcode/qr/qrcode.go | 166 +++ .../boombuler/barcode/qr/unicode.go | 27 + .../boombuler/barcode/qr/versioninfo.go | 310 +++++ .../boombuler/barcode/scaledbarcode.go | 134 ++ .../boombuler/barcode/utils/base1dcode.go | 57 + .../boombuler/barcode/utils/bitlist.go | 119 ++ .../boombuler/barcode/utils/galoisfield.go | 65 + .../boombuler/barcode/utils/gfpoly.go | 103 ++ .../boombuler/barcode/utils/reedsolomon.go | 44 + .../boombuler/barcode/utils/runeint.go | 19 + vendor/github.com/pquerna/otp/LICENSE | 202 +++ vendor/github.com/pquerna/otp/NOTICE | 5 + vendor/github.com/pquerna/otp/README.md | 60 + vendor/github.com/pquerna/otp/doc.go | 70 + vendor/github.com/pquerna/otp/example/main.go | 63 + vendor/github.com/pquerna/otp/hotp/hotp.go | 187 +++ vendor/github.com/pquerna/otp/otp.go | 200 +++ vendor/github.com/pquerna/otp/totp/totp.go | 191 +++ vendor/vendor.json | 36 + website/source/api/secret/totp/index.html.md | 272 ++++ .../source/docs/secrets/totp/index.html.md | 83 ++ website/source/layouts/api.erb | 3 + website/source/layouts/docs.erb | 4 + 37 files changed, 4836 insertions(+) create mode 100644 builtin/logical/totp/backend.go create mode 100644 builtin/logical/totp/backend_test.go create mode 100644 builtin/logical/totp/path_code.go create mode 100644 builtin/logical/totp/path_keys.go create mode 100644 vendor/github.com/boombuler/barcode/LICENSE create mode 100644 vendor/github.com/boombuler/barcode/README.md create mode 100644 vendor/github.com/boombuler/barcode/barcode.go create mode 100644 vendor/github.com/boombuler/barcode/qr/alphanumeric.go create mode 100644 vendor/github.com/boombuler/barcode/qr/automatic.go create mode 100644 vendor/github.com/boombuler/barcode/qr/blocks.go create mode 100644 vendor/github.com/boombuler/barcode/qr/encoder.go create mode 100644 vendor/github.com/boombuler/barcode/qr/errorcorrection.go create mode 100644 vendor/github.com/boombuler/barcode/qr/numeric.go create mode 100644 vendor/github.com/boombuler/barcode/qr/qrcode.go create mode 100644 vendor/github.com/boombuler/barcode/qr/unicode.go create mode 100644 vendor/github.com/boombuler/barcode/qr/versioninfo.go create mode 100644 vendor/github.com/boombuler/barcode/scaledbarcode.go create mode 100644 vendor/github.com/boombuler/barcode/utils/base1dcode.go create mode 100644 vendor/github.com/boombuler/barcode/utils/bitlist.go create mode 100644 vendor/github.com/boombuler/barcode/utils/galoisfield.go create mode 100644 vendor/github.com/boombuler/barcode/utils/gfpoly.go create mode 100644 vendor/github.com/boombuler/barcode/utils/reedsolomon.go create mode 100644 vendor/github.com/boombuler/barcode/utils/runeint.go create mode 100644 vendor/github.com/pquerna/otp/LICENSE create mode 100644 vendor/github.com/pquerna/otp/NOTICE create mode 100644 vendor/github.com/pquerna/otp/README.md create mode 100644 vendor/github.com/pquerna/otp/doc.go create mode 100644 vendor/github.com/pquerna/otp/example/main.go create mode 100644 vendor/github.com/pquerna/otp/hotp/hotp.go create mode 100644 vendor/github.com/pquerna/otp/otp.go create mode 100644 vendor/github.com/pquerna/otp/totp/totp.go create mode 100644 website/source/api/secret/totp/index.html.md create mode 100644 website/source/docs/secrets/totp/index.html.md diff --git a/builtin/logical/totp/backend.go b/builtin/logical/totp/backend.go new file mode 100644 index 000000000..4e3554bdb --- /dev/null +++ b/builtin/logical/totp/backend.go @@ -0,0 +1,37 @@ +package totp + +import ( + "strings" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func Factory(conf *logical.BackendConfig) (logical.Backend, error) { + return Backend(conf).Setup(conf) +} + +func Backend(conf *logical.BackendConfig) *backend { + var b backend + b.Backend = &framework.Backend{ + Help: strings.TrimSpace(backendHelp), + + Paths: []*framework.Path{ + pathListKeys(&b), + pathKeys(&b), + pathCode(&b), + }, + + Secrets: []*framework.Secret{}, + } + + return &b +} + +type backend struct { + *framework.Backend +} + +const backendHelp = ` +The TOTP backend dynamically generates time-based one-time use passwords. +` diff --git a/builtin/logical/totp/backend_test.go b/builtin/logical/totp/backend_test.go new file mode 100644 index 000000000..2a18a056f --- /dev/null +++ b/builtin/logical/totp/backend_test.go @@ -0,0 +1,1128 @@ +package totp + +import ( + "fmt" + "log" + "net/url" + "path" + "testing" + "time" + + "github.com/hashicorp/vault/logical" + logicaltest "github.com/hashicorp/vault/logical/testing" + "github.com/mitchellh/mapstructure" + otplib "github.com/pquerna/otp" + totplib "github.com/pquerna/otp/totp" +) + +func createKey() (string, error) { + keyUrl, err := totplib.Generate(totplib.GenerateOpts{ + Issuer: "Vault", + AccountName: "Test", + }) + + key := keyUrl.Secret() + + return key, err +} + +func generateCode(key string, period uint, digits otplib.Digits, algorithm otplib.Algorithm) (string, error) { + // Generate password using totp library + totpToken, err := totplib.GenerateCodeCustom(key, time.Now(), totplib.ValidateOpts{ + Period: period, + Digits: digits, + Algorithm: algorithm, + }) + + return totpToken, err +} + +func TestBackend_readCredentialsDefaultValues(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "key": key, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "", + "account_name": "", + "digits": otplib.DigitsSix, + "period": 30, + "algorithm": otplib.AlgorithmSHA1, + "key": key, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_readCredentialsEightDigitsThirtySecondPeriod(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "digits": 8, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "digits": otplib.DigitsEight, + "period": 30, + "algorithm": otplib.AlgorithmSHA1, + "key": key, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_readCredentialsSixDigitsNinetySecondPeriod(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "period": 90, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "digits": otplib.DigitsSix, + "period": 90, + "algorithm": otplib.AlgorithmSHA1, + "key": key, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_readCredentialsSHA256(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "algorithm": "SHA256", + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "digits": otplib.DigitsSix, + "period": 30, + "algorithm": otplib.AlgorithmSHA256, + "key": key, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_readCredentialsSHA512(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "algorithm": "SHA512", + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "digits": otplib.DigitsSix, + "period": 30, + "algorithm": otplib.AlgorithmSHA512, + "key": key, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_keyCrudDefaultValues(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "digits": otplib.DigitsSix, + "period": 30, + "algorithm": otplib.AlgorithmSHA1, + "key": key, + } + + code, _ := generateCode(key, 30, otplib.DigitsSix, otplib.AlgorithmSHA1) + invalidCode := "12345678" + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepValidateCode(t, "test", code, true), + testAccStepValidateCode(t, "test", invalidCode, false), + testAccStepDeleteKey(t, "test"), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_createKeyMissingKeyValue(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_createKeyInvalidKeyValue(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": "1", + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_createKeyInvalidAlgorithm(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "algorithm": "BADALGORITHM", + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_createKeyInvalidPeriod(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "period": -1, + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_createKeyInvalidDigits(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + // Generate a new shared key + key, _ := createKey() + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key": key, + "digits": 20, + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_generatedKeyDefaultValues(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "generate": true, + "key_size": 20, + "exported": true, + "qr_size": 200, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "digits": otplib.DigitsSix, + "period": 30, + "algorithm": otplib.AlgorithmSHA1, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + }, + }) +} + +func TestBackend_generatedKeyDefaultValuesNoQR(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "generate": true, + "key_size": 20, + "exported": true, + "qr_size": 0, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + }, + }) +} + +func TestBackend_generatedKeyNonDefaultKeySize(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "generate": true, + "key_size": 10, + "exported": true, + "qr_size": 200, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "digits": otplib.DigitsSix, + "period": 30, + "algorithm": otplib.AlgorithmSHA1, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + }, + }) +} + +func TestBackend_urlPassedNonGeneratedKeyInvalidPeriod(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + urlString := "otpauth://totp/Vault:test@email.com?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=6&period=AZ" + + keyData := map[string]interface{}{ + "url": urlString, + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_urlPassedNonGeneratedKeyInvalidDigits(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + urlString := "otpauth://totp/Vault:test@email.com?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=Q&period=60" + + keyData := map[string]interface{}{ + "url": urlString, + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_urlPassedNonGeneratedKeyIssuerInFirstPosition(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + urlString := "otpauth://totp/Vault:test@email.com?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=6&period=60" + + keyData := map[string]interface{}{ + "url": urlString, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "test@email.com", + "digits": otplib.DigitsSix, + "period": 60, + "algorithm": otplib.AlgorithmSHA512, + "key": "HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ", + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_urlPassedNonGeneratedKeyIssuerInQueryString(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + urlString := "otpauth://totp/test@email.com?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=6&period=60&issuer=Vault" + + keyData := map[string]interface{}{ + "url": urlString, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "test@email.com", + "digits": otplib.DigitsSix, + "period": 60, + "algorithm": otplib.AlgorithmSHA512, + "key": "HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ", + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_urlPassedNonGeneratedKeyMissingIssuer(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + urlString := "otpauth://totp/test@email.com?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=6&period=60" + + keyData := map[string]interface{}{ + "url": urlString, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "", + "account_name": "test@email.com", + "digits": otplib.DigitsSix, + "period": 60, + "algorithm": otplib.AlgorithmSHA512, + "key": "HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ", + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_urlPassedNonGeneratedKeyMissingAccountName(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + urlString := "otpauth://totp/Vault:?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=6&period=60" + + keyData := map[string]interface{}{ + "url": urlString, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "", + "digits": otplib.DigitsSix, + "period": 60, + "algorithm": otplib.AlgorithmSHA512, + "key": "HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ", + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_urlPassedNonGeneratedKeyMissingAccountNameandIssuer(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + urlString := "otpauth://totp/?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=6&period=60" + + keyData := map[string]interface{}{ + "url": urlString, + "generate": false, + } + + expected := map[string]interface{}{ + "issuer": "", + "account_name": "", + "digits": otplib.DigitsSix, + "period": 60, + "algorithm": otplib.AlgorithmSHA512, + "key": "HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ", + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + testAccStepReadCreds(t, b, config.StorageView, "test", expected), + }, + }) +} + +func TestBackend_generatedKeyInvalidSkew(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "skew": "2", + "generate": true, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_generatedKeyInvalidQRSize(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "qr_size": "-100", + "generate": true, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_generatedKeyInvalidKeySize(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "Test", + "key_size": "-100", + "generate": true, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_generatedKeyMissingAccountName(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "generate": true, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_generatedKeyMissingIssuer(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "account_name": "test@email.com", + "generate": true, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_invalidURLValue(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "url": "notaurl", + "generate": false, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_urlAndGenerateTrue(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "url": "otpauth://totp/Vault:test@email.com?secret=HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ&algorithm=SHA512&digits=6&period=60", + "generate": true, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_keyAndGenerateTrue(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "key": "HXDMVJECJJWSRB3HWIZR4IFUGFTMXBOZ", + "generate": true, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, true), + testAccStepReadKey(t, "test", nil), + }, + }) +} + +func TestBackend_generatedKeyExportedFalse(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + keyData := map[string]interface{}{ + "issuer": "Vault", + "account_name": "test@email.com", + "generate": true, + "exported": false, + } + + expected := map[string]interface{}{ + "issuer": "Vault", + "account_name": "test@email.com", + "digits": otplib.DigitsSix, + "period": 30, + "algorithm": otplib.AlgorithmSHA1, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepCreateKey(t, "test", keyData, false), + testAccStepReadKey(t, "test", expected), + }, + }) +} + +func testAccStepCreateKey(t *testing.T, name string, keyData map[string]interface{}, expectFail bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: path.Join("keys", name), + Data: keyData, + ErrorOk: expectFail, + Check: func(resp *logical.Response) error { + //Skip this if the key is not generated by vault or if the test is expected to fail + if !keyData["generate"].(bool) || expectFail { + return nil + } + + // Check to see if barcode and url were returned if exported is false + if !keyData["exported"].(bool) { + if resp != nil { + t.Fatalf("data was returned when exported was set to false") + } + return nil + } + + // Check to see if a barcode was returned when qr_size is zero + if keyData["qr_size"].(int) == 0 { + if _, exists := resp.Data["barcode"]; exists { + t.Fatalf("a barcode was returned when qr_size was set to zero") + } + return nil + } + + var d struct { + Url string `mapstructure:"url"` + Barcode string `mapstructure:"barcode"` + } + + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + //Check to see if barcode and url are returned + if d.Barcode == "" { + t.Fatalf("a barcode was not returned for a generated key") + } + + if d.Url == "" { + t.Fatalf("a url was not returned for a generated key") + } + + //Parse url + urlObject, err := url.Parse(d.Url) + + if err != nil { + t.Fatal("an error occured while parsing url string") + } + + //Set up query object + urlQuery := urlObject.Query() + + //Read secret + urlSecret := urlQuery.Get("secret") + + //Check key length + keySize := keyData["key_size"].(int) + correctSecretStringSize := (keySize / 5) * 8 + actualSecretStringSize := len(urlSecret) + + if actualSecretStringSize != correctSecretStringSize { + t.Fatal("incorrect key string length") + } + + return nil + }, + } +} + +func testAccStepDeleteKey(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.DeleteOperation, + Path: path.Join("keys", name), + } +} + +func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, validation map[string]interface{}) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("code", name), + Check: func(resp *logical.Response) error { + var d struct { + Code string `mapstructure:"code"` + } + + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + log.Printf("[TRACE] Generated credentials: %v", d) + + period := validation["period"].(int) + key := validation["key"].(string) + algorithm := validation["algorithm"].(otplib.Algorithm) + digits := validation["digits"].(otplib.Digits) + + valid, _ := totplib.ValidateCustom(d.Code, key, time.Now(), totplib.ValidateOpts{ + Period: uint(period), + Skew: 1, + Digits: digits, + Algorithm: algorithm, + }) + + if !valid { + t.Fatalf("generated code isn't valid") + } + + return nil + }, + } +} + +func testAccStepReadKey(t *testing.T, name string, expected map[string]interface{}) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "keys/" + name, + Check: func(resp *logical.Response) error { + if resp == nil { + if expected == nil { + return nil + } + return fmt.Errorf("bad: %#v", resp) + } + + var d struct { + Issuer string `mapstructure:"issuer"` + AccountName string `mapstructure:"account_name"` + Period uint `mapstructure:"period"` + Algorithm string `mapstructure:"algorithm"` + Digits otplib.Digits `mapstructure:"digits"` + } + + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + var keyAlgorithm otplib.Algorithm + switch d.Algorithm { + case "SHA1": + keyAlgorithm = otplib.AlgorithmSHA1 + case "SHA256": + keyAlgorithm = otplib.AlgorithmSHA256 + case "SHA512": + keyAlgorithm = otplib.AlgorithmSHA512 + } + + period := expected["period"].(int) + + switch { + case d.Issuer != expected["issuer"]: + return fmt.Errorf("issuer should equal: %s", expected["issuer"]) + case d.AccountName != expected["account_name"]: + return fmt.Errorf("account_name should equal: %s", expected["account_name"]) + case d.Period != uint(period): + return fmt.Errorf("period should equal: %d", expected["period"]) + case keyAlgorithm != expected["algorithm"]: + return fmt.Errorf("algorithm should equal: %s", expected["algorithm"]) + case d.Digits != expected["digits"]: + return fmt.Errorf("digits should equal: %d", expected["digits"]) + } + return nil + }, + } +} + +func testAccStepValidateCode(t *testing.T, name string, code string, valid bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "code/" + name, + Data: map[string]interface{}{ + "code": code, + }, + Check: func(resp *logical.Response) error { + if resp == nil { + return fmt.Errorf("bad: %#v", resp) + } + + var d struct { + Valid bool `mapstructure:"valid"` + } + + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + switch valid { + case true: + if d.Valid != true { + return fmt.Errorf("code was not valid: %s", code) + } + + default: + if d.Valid != false { + return fmt.Errorf("code was incorrectly validated: %s", code) + } + } + return nil + }, + } +} diff --git a/builtin/logical/totp/path_code.go b/builtin/logical/totp/path_code.go new file mode 100644 index 000000000..0481db145 --- /dev/null +++ b/builtin/logical/totp/path_code.go @@ -0,0 +1,110 @@ +package totp + +import ( + "fmt" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + otplib "github.com/pquerna/otp" + totplib "github.com/pquerna/otp/totp" +) + +func pathCode(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "code/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the key.", + }, + "code": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "TOTP code to be validated.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathReadCode, + logical.UpdateOperation: b.pathValidateCode, + }, + + HelpSynopsis: pathCodeHelpSyn, + HelpDescription: pathCodeHelpDesc, + } +} + +func (b *backend) pathReadCode( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + + // Get the key + key, err := b.Key(req.Storage, name) + if err != nil { + return nil, err + } + if key == nil { + return logical.ErrorResponse(fmt.Sprintf("unknown key: %s", name)), nil + } + + // Generate password using totp library + totpToken, err := totplib.GenerateCodeCustom(key.Key, time.Now(), totplib.ValidateOpts{ + Period: key.Period, + Digits: key.Digits, + Algorithm: key.Algorithm, + }) + if err != nil { + return nil, err + } + + // Return the secret + return &logical.Response{ + Data: map[string]interface{}{ + "code": totpToken, + }, + }, nil +} + +func (b *backend) pathValidateCode( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + code := data.Get("code").(string) + + // Enforce input value requirements + if code == "" { + return logical.ErrorResponse("the code value is required"), nil + } + + // Get the key's stored values + key, err := b.Key(req.Storage, name) + if err != nil { + return nil, err + } + if key == nil { + return logical.ErrorResponse(fmt.Sprintf("unknown key: %s", name)), nil + } + + valid, err := totplib.ValidateCustom(code, key.Key, time.Now(), totplib.ValidateOpts{ + Period: key.Period, + Skew: key.Skew, + Digits: key.Digits, + Algorithm: key.Algorithm, + }) + if err != nil && err != otplib.ErrValidateInputInvalidLength { + return logical.ErrorResponse("an error occured while validating the code"), err + } + + return &logical.Response{ + Data: map[string]interface{}{ + "valid": valid, + }, + }, nil +} + +const pathCodeHelpSyn = ` +Request time-based one-time use password or validate a password for a certain key . +` +const pathCodeHelpDesc = ` +This path generates and validates time-based one-time use passwords for a certain key. + +` diff --git a/builtin/logical/totp/path_keys.go b/builtin/logical/totp/path_keys.go new file mode 100644 index 000000000..3f36aef0f --- /dev/null +++ b/builtin/logical/totp/path_keys.go @@ -0,0 +1,424 @@ +package totp + +import ( + "bytes" + "encoding/base32" + "encoding/base64" + "fmt" + "image/png" + "net/url" + "strconv" + "strings" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + otplib "github.com/pquerna/otp" + totplib "github.com/pquerna/otp/totp" +) + +func pathListKeys(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "keys/?$", + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: b.pathKeyList, + }, + + HelpSynopsis: pathKeyHelpSyn, + HelpDescription: pathKeyHelpDesc, + } +} + +func pathKeys(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "keys/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Description: "Name of the key.", + }, + + "generate": { + Type: framework.TypeBool, + Default: false, + Description: "Determines if a key should be generated by Vault or if a key is being passed from another service.", + }, + + "exported": { + Type: framework.TypeBool, + Default: true, + Description: "Determines if a QR code and url are returned upon generating a key. Only used if generate is true.", + }, + + "key_size": { + Type: framework.TypeInt, + Default: 20, + Description: "Determines the size in bytes of the generated key. Only used if generate is true.", + }, + + "key": { + Type: framework.TypeString, + Description: "The shared master key used to generate a TOTP token. Only used if generate is false.", + }, + + "issuer": { + Type: framework.TypeString, + Description: `The name of the key's issuing organization. Required if generate is true.`, + }, + + "account_name": { + Type: framework.TypeString, + Description: `The name of the account associated with the key. Required if generate is true.`, + }, + + "period": { + Type: framework.TypeDurationSecond, + Default: 30, + Description: `The length of time used to generate a counter for the TOTP token calculation.`, + }, + + "algorithm": { + Type: framework.TypeString, + Default: "SHA1", + Description: `The hashing algorithm used to generate the TOTP token. Options include SHA1, SHA256 and SHA512.`, + }, + + "digits": { + Type: framework.TypeInt, + Default: 6, + Description: `The number of digits in the generated TOTP token. This value can either be 6 or 8.`, + }, + + "skew": { + Type: framework.TypeInt, + Default: 1, + Description: `The number of delay periods that are allowed when validating a TOTP token. This value can either be 0 or 1. Only used if generate is true.`, + }, + + "qr_size": { + Type: framework.TypeInt, + Default: 200, + Description: `The pixel size of the generated square QR code. Only used if generate is true and exported is true. If this value is 0, a QR code will not be returned.`, + }, + + "url": { + Type: framework.TypeString, + Description: `A TOTP url string containing all of the parameters for key setup. Only used if generate is false.`, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathKeyRead, + logical.UpdateOperation: b.pathKeyCreate, + logical.DeleteOperation: b.pathKeyDelete, + }, + + HelpSynopsis: pathKeyHelpSyn, + HelpDescription: pathKeyHelpDesc, + } +} + +func (b *backend) Key(s logical.Storage, n string) (*keyEntry, error) { + entry, err := s.Get("key/" + n) + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result keyEntry + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +func (b *backend) pathKeyDelete( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + err := req.Storage.Delete("key/" + data.Get("name").(string)) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *backend) pathKeyRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + key, err := b.Key(req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if key == nil { + return nil, nil + } + + // Translate algorithm back to string + algorithm := key.Algorithm.String() + + // Return values of key + return &logical.Response{ + Data: map[string]interface{}{ + "issuer": key.Issuer, + "account_name": key.AccountName, + "period": key.Period, + "algorithm": algorithm, + "digits": key.Digits, + }, + }, nil +} + +func (b *backend) pathKeyList( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + entries, err := req.Storage.List("key/") + if err != nil { + return nil, err + } + + return logical.ListResponse(entries), nil +} + +func (b *backend) pathKeyCreate( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + generate := data.Get("generate").(bool) + exported := data.Get("exported").(bool) + keyString := data.Get("key").(string) + issuer := data.Get("issuer").(string) + accountName := data.Get("account_name").(string) + period := data.Get("period").(int) + algorithm := data.Get("algorithm").(string) + digits := data.Get("digits").(int) + skew := data.Get("skew").(int) + qrSize := data.Get("qr_size").(int) + keySize := data.Get("key_size").(int) + inputURL := data.Get("url").(string) + + if generate { + if keyString != "" { + return logical.ErrorResponse("a key should not be passed if generate is true"), nil + } + if inputURL != "" { + return logical.ErrorResponse("a url should not be passed if generate is true"), nil + } + } + + // Read parameters from url if given + if inputURL != "" { + //Parse url + urlObject, err := url.Parse(inputURL) + if err != nil { + return logical.ErrorResponse("an error occured while parsing url string"), err + } + + //Set up query object + urlQuery := urlObject.Query() + path := strings.TrimPrefix(urlObject.Path, "/") + index := strings.Index(path, ":") + + //Read issuer + urlIssuer := urlQuery.Get("issuer") + if urlIssuer != "" { + issuer = urlIssuer + } else { + if index != -1 { + issuer = path[:index] + } + } + + //Read account name + if index == -1 { + accountName = path + } else { + accountName = path[index+1:] + } + + //Read key string + keyString = urlQuery.Get("secret") + + //Read period + periodQuery := urlQuery.Get("period") + if periodQuery != "" { + periodInt, err := strconv.Atoi(periodQuery) + if err != nil { + return logical.ErrorResponse("an error occured while parsing period value in url"), err + } + period = periodInt + } + + //Read digits + digitsQuery := urlQuery.Get("digits") + if digitsQuery != "" { + digitsInt, err := strconv.Atoi(digitsQuery) + if err != nil { + return logical.ErrorResponse("an error occured while parsing digits value in url"), err + } + digits = digitsInt + } + + //Read algorithm + algorithmQuery := urlQuery.Get("algorithm") + if algorithmQuery != "" { + algorithm = algorithmQuery + } + } + + // Translate digits and algorithm to a format the totp library understands + var keyDigits otplib.Digits + switch digits { + case 6: + keyDigits = otplib.DigitsSix + case 8: + keyDigits = otplib.DigitsEight + default: + return logical.ErrorResponse("the digits value can only be 6 or 8"), nil + } + + var keyAlgorithm otplib.Algorithm + switch algorithm { + case "SHA1": + keyAlgorithm = otplib.AlgorithmSHA1 + case "SHA256": + keyAlgorithm = otplib.AlgorithmSHA256 + case "SHA512": + keyAlgorithm = otplib.AlgorithmSHA512 + default: + return logical.ErrorResponse("the algorithm value is not valid"), nil + } + + // Enforce input value requirements + if period <= 0 { + return logical.ErrorResponse("the period value must be greater than zero"), nil + } + + switch skew { + case 0: + case 1: + default: + return logical.ErrorResponse("the skew value must be 0 or 1"), nil + } + + // QR size can be zero but it shouldn't be negative + if qrSize < 0 { + return logical.ErrorResponse("the qr_size value must be greater than or equal to zero"), nil + } + + if keySize <= 0 { + return logical.ErrorResponse("the key_size value must be greater than zero"), nil + } + + // Period, Skew and Key Size need to be unsigned ints + uintPeriod := uint(period) + uintSkew := uint(skew) + uintKeySize := uint(keySize) + + var response *logical.Response + + switch generate { + case true: + // If the key is generated, Account Name and Issuer are required. + if accountName == "" { + return logical.ErrorResponse("the account_name value is required for generated keys"), nil + } + + if issuer == "" { + return logical.ErrorResponse("the issuer value is required for generated keys"), nil + } + + // Generate a new key + keyObject, err := totplib.Generate(totplib.GenerateOpts{ + Issuer: issuer, + AccountName: accountName, + Period: uintPeriod, + Digits: keyDigits, + Algorithm: keyAlgorithm, + SecretSize: uintKeySize, + }) + if err != nil { + return logical.ErrorResponse("an error occured while generating a key"), err + } + + // Get key string value + keyString = keyObject.Secret() + + // Skip returning the QR code and url if exported is set to false + if exported { + // Prepare the url and barcode + urlString := keyObject.String() + + // Don't include QR code is size is set to zero + if qrSize == 0 { + response = &logical.Response{ + Data: map[string]interface{}{ + "url": urlString, + }, + } + } else { + barcode, err := keyObject.Image(qrSize, qrSize) + if err != nil { + return logical.ErrorResponse("an error occured while generating a QR code image"), err + } + + var buff bytes.Buffer + png.Encode(&buff, barcode) + b64Barcode := base64.StdEncoding.EncodeToString(buff.Bytes()) + response = &logical.Response{ + Data: map[string]interface{}{ + "url": urlString, + "barcode": b64Barcode, + }, + } + } + } + default: + if keyString == "" { + return logical.ErrorResponse("the key value is required"), nil + } + + _, err := base32.StdEncoding.DecodeString(keyString) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "invalid key value: %s", err)), nil + } + } + + // Store it + entry, err := logical.StorageEntryJSON("key/"+name, &keyEntry{ + Key: keyString, + Issuer: issuer, + AccountName: accountName, + Period: uintPeriod, + Algorithm: keyAlgorithm, + Digits: keyDigits, + Skew: uintSkew, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return response, nil +} + +type keyEntry struct { + Key string `json:"key" mapstructure:"key" structs:"key"` + Issuer string `json:"issuer" mapstructure:"issuer" structs:"issuer"` + AccountName string `json:"account_name" mapstructure:"account_name" structs:"account_name"` + Period uint `json:"period" mapstructure:"period" structs:"period"` + Algorithm otplib.Algorithm `json:"algorithm" mapstructure:"algorithm" structs:"algorithm"` + Digits otplib.Digits `json:"digits" mapstructure:"digits" structs:"digits"` + Skew uint `json:"skew" mapstructure:"skew" structs:"skew"` +} + +const pathKeyHelpSyn = ` +Manage the keys that can be created with this backend. +` + +const pathKeyHelpDesc = ` +This path lets you manage the keys that can be created with this backend. + +` diff --git a/cli/commands.go b/cli/commands.go index 7494c0676..72abe953a 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -28,6 +28,7 @@ import ( "github.com/hashicorp/vault/builtin/logical/postgresql" "github.com/hashicorp/vault/builtin/logical/rabbitmq" "github.com/hashicorp/vault/builtin/logical/ssh" + "github.com/hashicorp/vault/builtin/logical/totp" "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/audit" @@ -91,6 +92,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { "mysql": mysql.Factory, "ssh": ssh.Factory, "rabbitmq": rabbitmq.Factory, + "totp": totp.Factory, }, ShutdownCh: command.MakeShutdownCh(), SighupCh: command.MakeSighupCh(), diff --git a/vendor/github.com/boombuler/barcode/LICENSE b/vendor/github.com/boombuler/barcode/LICENSE new file mode 100644 index 000000000..862b0ddcd --- /dev/null +++ b/vendor/github.com/boombuler/barcode/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Florian Sundermann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/boombuler/barcode/README.md b/vendor/github.com/boombuler/barcode/README.md new file mode 100644 index 000000000..85c34d639 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/README.md @@ -0,0 +1,18 @@ +##Introduction## +This is a package for GO which can be used to create different types of barcodes. + +##Supported Barcode Types## +* Aztec Code +* Codabar +* Code 128 +* Code 39 +* EAN 8 +* EAN 13 +* Datamatrix +* QR Codes +* 2 of 5 + +##Documentation## +See [GoDoc](https://godoc.org/github.com/boombuler/barcode) + +To create a barcode use the Encode function from one of the subpackages. diff --git a/vendor/github.com/boombuler/barcode/barcode.go b/vendor/github.com/boombuler/barcode/barcode.go new file mode 100644 index 000000000..3479c7bc2 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/barcode.go @@ -0,0 +1,27 @@ +package barcode + +import "image" + +// Contains some meta information about a barcode +type Metadata struct { + // the name of the barcode kind + CodeKind string + // contains 1 for 1D barcodes or 2 for 2D barcodes + Dimensions byte +} + +// a rendered and encoded barcode +type Barcode interface { + image.Image + // returns some meta information about the barcode + Metadata() Metadata + // the data that was encoded in this barcode + Content() string +} + +// Additional interface that some barcodes might implement to provide +// the value of its checksum. +type BarcodeIntCS interface { + Barcode + CheckSum() int +} diff --git a/vendor/github.com/boombuler/barcode/qr/alphanumeric.go b/vendor/github.com/boombuler/barcode/qr/alphanumeric.go new file mode 100644 index 000000000..4ded7c8e0 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/alphanumeric.go @@ -0,0 +1,66 @@ +package qr + +import ( + "errors" + "fmt" + "strings" + + "github.com/boombuler/barcode/utils" +) + +const charSet string = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ $%*+-./:" + +func stringToAlphaIdx(content string) <-chan int { + result := make(chan int) + go func() { + for _, r := range content { + idx := strings.IndexRune(charSet, r) + result <- idx + if idx < 0 { + break + } + } + close(result) + }() + + return result +} + +func encodeAlphaNumeric(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) { + + contentLenIsOdd := len(content)%2 == 1 + contentBitCount := (len(content) / 2) * 11 + if contentLenIsOdd { + contentBitCount += 6 + } + vi := findSmallestVersionInfo(ecl, alphaNumericMode, contentBitCount) + if vi == nil { + return nil, nil, errors.New("To much data to encode") + } + + res := new(utils.BitList) + res.AddBits(int(alphaNumericMode), 4) + res.AddBits(len(content), vi.charCountBits(alphaNumericMode)) + + encoder := stringToAlphaIdx(content) + + for idx := 0; idx < len(content)/2; idx++ { + c1 := <-encoder + c2 := <-encoder + if c1 < 0 || c2 < 0 { + return nil, nil, fmt.Errorf("\"%s\" can not be encoded as %s", content, AlphaNumeric) + } + res.AddBits(c1*45+c2, 11) + } + if contentLenIsOdd { + c := <-encoder + if c < 0 { + return nil, nil, fmt.Errorf("\"%s\" can not be encoded as %s", content, AlphaNumeric) + } + res.AddBits(c, 6) + } + + addPaddingAndTerminator(res, vi) + + return res, vi, nil +} diff --git a/vendor/github.com/boombuler/barcode/qr/automatic.go b/vendor/github.com/boombuler/barcode/qr/automatic.go new file mode 100644 index 000000000..e7c56013f --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/automatic.go @@ -0,0 +1,23 @@ +package qr + +import ( + "fmt" + + "github.com/boombuler/barcode/utils" +) + +func encodeAuto(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) { + bits, vi, _ := Numeric.getEncoder()(content, ecl) + if bits != nil && vi != nil { + return bits, vi, nil + } + bits, vi, _ = AlphaNumeric.getEncoder()(content, ecl) + if bits != nil && vi != nil { + return bits, vi, nil + } + bits, vi, _ = Unicode.getEncoder()(content, ecl) + if bits != nil && vi != nil { + return bits, vi, nil + } + return nil, nil, fmt.Errorf("No encoding found to encode \"%s\"", content) +} diff --git a/vendor/github.com/boombuler/barcode/qr/blocks.go b/vendor/github.com/boombuler/barcode/qr/blocks.go new file mode 100644 index 000000000..d3173787f --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/blocks.go @@ -0,0 +1,59 @@ +package qr + +type block struct { + data []byte + ecc []byte +} +type blockList []*block + +func splitToBlocks(data <-chan byte, vi *versionInfo) blockList { + result := make(blockList, vi.NumberOfBlocksInGroup1+vi.NumberOfBlocksInGroup2) + + for b := 0; b < int(vi.NumberOfBlocksInGroup1); b++ { + blk := new(block) + blk.data = make([]byte, vi.DataCodeWordsPerBlockInGroup1) + for cw := 0; cw < int(vi.DataCodeWordsPerBlockInGroup1); cw++ { + blk.data[cw] = <-data + } + blk.ecc = ec.calcECC(blk.data, vi.ErrorCorrectionCodewordsPerBlock) + result[b] = blk + } + + for b := 0; b < int(vi.NumberOfBlocksInGroup2); b++ { + blk := new(block) + blk.data = make([]byte, vi.DataCodeWordsPerBlockInGroup2) + for cw := 0; cw < int(vi.DataCodeWordsPerBlockInGroup2); cw++ { + blk.data[cw] = <-data + } + blk.ecc = ec.calcECC(blk.data, vi.ErrorCorrectionCodewordsPerBlock) + result[int(vi.NumberOfBlocksInGroup1)+b] = blk + } + + return result +} + +func (bl blockList) interleave(vi *versionInfo) []byte { + var maxCodewordCount int + if vi.DataCodeWordsPerBlockInGroup1 > vi.DataCodeWordsPerBlockInGroup2 { + maxCodewordCount = int(vi.DataCodeWordsPerBlockInGroup1) + } else { + maxCodewordCount = int(vi.DataCodeWordsPerBlockInGroup2) + } + resultLen := (vi.DataCodeWordsPerBlockInGroup1+vi.ErrorCorrectionCodewordsPerBlock)*vi.NumberOfBlocksInGroup1 + + (vi.DataCodeWordsPerBlockInGroup2+vi.ErrorCorrectionCodewordsPerBlock)*vi.NumberOfBlocksInGroup2 + + result := make([]byte, 0, resultLen) + for i := 0; i < maxCodewordCount; i++ { + for b := 0; b < len(bl); b++ { + if len(bl[b].data) > i { + result = append(result, bl[b].data[i]) + } + } + } + for i := 0; i < int(vi.ErrorCorrectionCodewordsPerBlock); i++ { + for b := 0; b < len(bl); b++ { + result = append(result, bl[b].ecc[i]) + } + } + return result +} diff --git a/vendor/github.com/boombuler/barcode/qr/encoder.go b/vendor/github.com/boombuler/barcode/qr/encoder.go new file mode 100644 index 000000000..2c6ab2111 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/encoder.go @@ -0,0 +1,416 @@ +// Package qr can be used to create QR barcodes. +package qr + +import ( + "image" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/utils" +) + +type encodeFn func(content string, eccLevel ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) + +// Encoding mode for QR Codes. +type Encoding byte + +const ( + // Auto will choose ths best matching encoding + Auto Encoding = iota + // Numeric encoding only encodes numbers [0-9] + Numeric + // AlphaNumeric encoding only encodes uppercase letters, numbers and [Space], $, %, *, +, -, ., /, : + AlphaNumeric + // Unicode encoding encodes the string as utf-8 + Unicode + // only for testing purpose + unknownEncoding +) + +func (e Encoding) getEncoder() encodeFn { + switch e { + case Auto: + return encodeAuto + case Numeric: + return encodeNumeric + case AlphaNumeric: + return encodeAlphaNumeric + case Unicode: + return encodeUnicode + } + return nil +} + +func (e Encoding) String() string { + switch e { + case Auto: + return "Auto" + case Numeric: + return "Numeric" + case AlphaNumeric: + return "AlphaNumeric" + case Unicode: + return "Unicode" + } + return "" +} + +// Encode returns a QR barcode with the given content, error correction level and uses the given encoding +func Encode(content string, level ErrorCorrectionLevel, mode Encoding) (barcode.Barcode, error) { + bits, vi, err := mode.getEncoder()(content, level) + if err != nil { + return nil, err + } + + blocks := splitToBlocks(bits.IterateBytes(), vi) + data := blocks.interleave(vi) + result := render(data, vi) + result.content = content + return result, nil +} + +func render(data []byte, vi *versionInfo) *qrcode { + dim := vi.modulWidth() + results := make([]*qrcode, 8) + for i := 0; i < 8; i++ { + results[i] = newBarcode(dim) + } + + occupied := newBarcode(dim) + + setAll := func(x int, y int, val bool) { + occupied.Set(x, y, true) + for i := 0; i < 8; i++ { + results[i].Set(x, y, val) + } + } + + drawFinderPatterns(vi, setAll) + drawAlignmentPatterns(occupied, vi, setAll) + + //Timing Pattern: + var i int + for i = 0; i < dim; i++ { + if !occupied.Get(i, 6) { + setAll(i, 6, i%2 == 0) + } + if !occupied.Get(6, i) { + setAll(6, i, i%2 == 0) + } + } + // Dark Module + setAll(8, dim-8, true) + + drawVersionInfo(vi, setAll) + drawFormatInfo(vi, -1, occupied.Set) + for i := 0; i < 8; i++ { + drawFormatInfo(vi, i, results[i].Set) + } + + // Write the data + var curBitNo int + + for pos := range iterateModules(occupied) { + var curBit bool + if curBitNo < len(data)*8 { + curBit = ((data[curBitNo/8] >> uint(7-(curBitNo%8))) & 1) == 1 + } else { + curBit = false + } + + for i := 0; i < 8; i++ { + setMasked(pos.X, pos.Y, curBit, i, results[i].Set) + } + curBitNo++ + } + + lowestPenalty := ^uint(0) + lowestPenaltyIdx := -1 + for i := 0; i < 8; i++ { + p := results[i].calcPenalty() + if p < lowestPenalty { + lowestPenalty = p + lowestPenaltyIdx = i + } + } + return results[lowestPenaltyIdx] +} + +func setMasked(x, y int, val bool, mask int, set func(int, int, bool)) { + switch mask { + case 0: + val = val != (((y + x) % 2) == 0) + break + case 1: + val = val != ((y % 2) == 0) + break + case 2: + val = val != ((x % 3) == 0) + break + case 3: + val = val != (((y + x) % 3) == 0) + break + case 4: + val = val != (((y/2 + x/3) % 2) == 0) + break + case 5: + val = val != (((y*x)%2)+((y*x)%3) == 0) + break + case 6: + val = val != ((((y*x)%2)+((y*x)%3))%2 == 0) + break + case 7: + val = val != ((((y+x)%2)+((y*x)%3))%2 == 0) + } + set(x, y, val) +} + +func iterateModules(occupied *qrcode) <-chan image.Point { + result := make(chan image.Point) + allPoints := make(chan image.Point) + go func() { + curX := occupied.dimension - 1 + curY := occupied.dimension - 1 + isUpward := true + + for true { + if isUpward { + allPoints <- image.Pt(curX, curY) + allPoints <- image.Pt(curX-1, curY) + curY-- + if curY < 0 { + curY = 0 + curX -= 2 + if curX == 6 { + curX-- + } + if curX < 0 { + break + } + isUpward = false + } + } else { + allPoints <- image.Pt(curX, curY) + allPoints <- image.Pt(curX-1, curY) + curY++ + if curY >= occupied.dimension { + curY = occupied.dimension - 1 + curX -= 2 + if curX == 6 { + curX-- + } + isUpward = true + if curX < 0 { + break + } + } + } + } + + close(allPoints) + }() + go func() { + for pt := range allPoints { + if !occupied.Get(pt.X, pt.Y) { + result <- pt + } + } + close(result) + }() + return result +} + +func drawFinderPatterns(vi *versionInfo, set func(int, int, bool)) { + dim := vi.modulWidth() + drawPattern := func(xoff int, yoff int) { + for x := -1; x < 8; x++ { + for y := -1; y < 8; y++ { + val := (x == 0 || x == 6 || y == 0 || y == 6 || (x > 1 && x < 5 && y > 1 && y < 5)) && (x <= 6 && y <= 6 && x >= 0 && y >= 0) + + if x+xoff >= 0 && x+xoff < dim && y+yoff >= 0 && y+yoff < dim { + set(x+xoff, y+yoff, val) + } + } + } + } + drawPattern(0, 0) + drawPattern(0, dim-7) + drawPattern(dim-7, 0) +} + +func drawAlignmentPatterns(occupied *qrcode, vi *versionInfo, set func(int, int, bool)) { + drawPattern := func(xoff int, yoff int) { + for x := -2; x <= 2; x++ { + for y := -2; y <= 2; y++ { + val := x == -2 || x == 2 || y == -2 || y == 2 || (x == 0 && y == 0) + set(x+xoff, y+yoff, val) + } + } + } + positions := vi.alignmentPatternPlacements() + + for _, x := range positions { + for _, y := range positions { + if occupied.Get(x, y) { + continue + } + drawPattern(x, y) + } + } +} + +var formatInfos = map[ErrorCorrectionLevel]map[int][]bool{ + L: { + 0: []bool{true, true, true, false, true, true, true, true, true, false, false, false, true, false, false}, + 1: []bool{true, true, true, false, false, true, false, true, true, true, true, false, false, true, true}, + 2: []bool{true, true, true, true, true, false, true, true, false, true, false, true, false, true, false}, + 3: []bool{true, true, true, true, false, false, false, true, false, false, true, true, true, false, true}, + 4: []bool{true, true, false, false, true, true, false, false, false, true, false, true, true, true, true}, + 5: []bool{true, true, false, false, false, true, true, false, false, false, true, true, false, false, false}, + 6: []bool{true, true, false, true, true, false, false, false, true, false, false, false, false, false, true}, + 7: []bool{true, true, false, true, false, false, true, false, true, true, true, false, true, true, false}, + }, + M: { + 0: []bool{true, false, true, false, true, false, false, false, false, false, true, false, false, true, false}, + 1: []bool{true, false, true, false, false, false, true, false, false, true, false, false, true, false, true}, + 2: []bool{true, false, true, true, true, true, false, false, true, true, true, true, true, false, false}, + 3: []bool{true, false, true, true, false, true, true, false, true, false, false, true, false, true, true}, + 4: []bool{true, false, false, false, true, false, true, true, true, true, true, true, false, false, true}, + 5: []bool{true, false, false, false, false, false, false, true, true, false, false, true, true, true, false}, + 6: []bool{true, false, false, true, true, true, true, true, false, false, true, false, true, true, true}, + 7: []bool{true, false, false, true, false, true, false, true, false, true, false, false, false, false, false}, + }, + Q: { + 0: []bool{false, true, true, false, true, false, true, false, true, false, true, true, true, true, true}, + 1: []bool{false, true, true, false, false, false, false, false, true, true, false, true, false, false, false}, + 2: []bool{false, true, true, true, true, true, true, false, false, true, true, false, false, false, true}, + 3: []bool{false, true, true, true, false, true, false, false, false, false, false, false, true, true, false}, + 4: []bool{false, true, false, false, true, false, false, true, false, true, true, false, true, false, false}, + 5: []bool{false, true, false, false, false, false, true, true, false, false, false, false, false, true, true}, + 6: []bool{false, true, false, true, true, true, false, true, true, false, true, true, false, true, false}, + 7: []bool{false, true, false, true, false, true, true, true, true, true, false, true, true, false, true}, + }, + H: { + 0: []bool{false, false, true, false, true, true, false, true, false, false, false, true, false, false, true}, + 1: []bool{false, false, true, false, false, true, true, true, false, true, true, true, true, true, false}, + 2: []bool{false, false, true, true, true, false, false, true, true, true, false, false, true, true, true}, + 3: []bool{false, false, true, true, false, false, true, true, true, false, true, false, false, false, false}, + 4: []bool{false, false, false, false, true, true, true, false, true, true, false, false, false, true, false}, + 5: []bool{false, false, false, false, false, true, false, false, true, false, true, false, true, false, true}, + 6: []bool{false, false, false, true, true, false, true, false, false, false, false, true, true, false, false}, + 7: []bool{false, false, false, true, false, false, false, false, false, true, true, true, false, true, true}, + }, +} + +func drawFormatInfo(vi *versionInfo, usedMask int, set func(int, int, bool)) { + var formatInfo []bool + + if usedMask == -1 { + formatInfo = []bool{true, true, true, true, true, true, true, true, true, true, true, true, true, true, true} // Set all to true cause -1 --> occupied mask. + } else { + formatInfo = formatInfos[vi.Level][usedMask] + } + + if len(formatInfo) == 15 { + dim := vi.modulWidth() + set(0, 8, formatInfo[0]) + set(1, 8, formatInfo[1]) + set(2, 8, formatInfo[2]) + set(3, 8, formatInfo[3]) + set(4, 8, formatInfo[4]) + set(5, 8, formatInfo[5]) + set(7, 8, formatInfo[6]) + set(8, 8, formatInfo[7]) + set(8, 7, formatInfo[8]) + set(8, 5, formatInfo[9]) + set(8, 4, formatInfo[10]) + set(8, 3, formatInfo[11]) + set(8, 2, formatInfo[12]) + set(8, 1, formatInfo[13]) + set(8, 0, formatInfo[14]) + + set(8, dim-1, formatInfo[0]) + set(8, dim-2, formatInfo[1]) + set(8, dim-3, formatInfo[2]) + set(8, dim-4, formatInfo[3]) + set(8, dim-5, formatInfo[4]) + set(8, dim-6, formatInfo[5]) + set(8, dim-7, formatInfo[6]) + set(dim-8, 8, formatInfo[7]) + set(dim-7, 8, formatInfo[8]) + set(dim-6, 8, formatInfo[9]) + set(dim-5, 8, formatInfo[10]) + set(dim-4, 8, formatInfo[11]) + set(dim-3, 8, formatInfo[12]) + set(dim-2, 8, formatInfo[13]) + set(dim-1, 8, formatInfo[14]) + } +} + +var versionInfoBitsByVersion = map[byte][]bool{ + 7: []bool{false, false, false, true, true, true, true, true, false, false, true, false, false, true, false, true, false, false}, + 8: []bool{false, false, true, false, false, false, false, true, false, true, true, false, true, true, true, true, false, false}, + 9: []bool{false, false, true, false, false, true, true, false, true, false, true, false, false, true, true, false, false, true}, + 10: []bool{false, false, true, false, true, false, false, true, false, false, true, true, false, true, false, false, true, true}, + 11: []bool{false, false, true, false, true, true, true, false, true, true, true, true, true, true, false, true, true, false}, + 12: []bool{false, false, true, true, false, false, false, true, true, true, false, true, true, false, false, false, true, false}, + 13: []bool{false, false, true, true, false, true, true, false, false, false, false, true, false, false, false, true, true, true}, + 14: []bool{false, false, true, true, true, false, false, true, true, false, false, false, false, false, true, true, false, true}, + 15: []bool{false, false, true, true, true, true, true, false, false, true, false, false, true, false, true, false, false, false}, + 16: []bool{false, true, false, false, false, false, true, false, true, true, false, true, true, true, true, false, false, false}, + 17: []bool{false, true, false, false, false, true, false, true, false, false, false, true, false, true, true, true, false, true}, + 18: []bool{false, true, false, false, true, false, true, false, true, false, false, false, false, true, false, true, true, true}, + 19: []bool{false, true, false, false, true, true, false, true, false, true, false, false, true, true, false, false, true, false}, + 20: []bool{false, true, false, true, false, false, true, false, false, true, true, false, true, false, false, true, true, false}, + 21: []bool{false, true, false, true, false, true, false, true, true, false, true, false, false, false, false, false, true, true}, + 22: []bool{false, true, false, true, true, false, true, false, false, false, true, true, false, false, true, false, false, true}, + 23: []bool{false, true, false, true, true, true, false, true, true, true, true, true, true, false, true, true, false, false}, + 24: []bool{false, true, true, false, false, false, true, true, true, false, true, true, false, false, false, true, false, false}, + 25: []bool{false, true, true, false, false, true, false, false, false, true, true, true, true, false, false, false, false, true}, + 26: []bool{false, true, true, false, true, false, true, true, true, true, true, false, true, false, true, false, true, true}, + 27: []bool{false, true, true, false, true, true, false, false, false, false, true, false, false, false, true, true, true, false}, + 28: []bool{false, true, true, true, false, false, true, true, false, false, false, false, false, true, true, false, true, false}, + 29: []bool{false, true, true, true, false, true, false, false, true, true, false, false, true, true, true, true, true, true}, + 30: []bool{false, true, true, true, true, false, true, true, false, true, false, true, true, true, false, true, false, true}, + 31: []bool{false, true, true, true, true, true, false, false, true, false, false, true, false, true, false, false, false, false}, + 32: []bool{true, false, false, false, false, false, true, false, false, true, true, true, false, true, false, true, false, true}, + 33: []bool{true, false, false, false, false, true, false, true, true, false, true, true, true, true, false, false, false, false}, + 34: []bool{true, false, false, false, true, false, true, false, false, false, true, false, true, true, true, false, true, false}, + 35: []bool{true, false, false, false, true, true, false, true, true, true, true, false, false, true, true, true, true, true}, + 36: []bool{true, false, false, true, false, false, true, false, true, true, false, false, false, false, true, false, true, true}, + 37: []bool{true, false, false, true, false, true, false, true, false, false, false, false, true, false, true, true, true, false}, + 38: []bool{true, false, false, true, true, false, true, false, true, false, false, true, true, false, false, true, false, false}, + 39: []bool{true, false, false, true, true, true, false, true, false, true, false, true, false, false, false, false, false, true}, + 40: []bool{true, false, true, false, false, false, true, true, false, false, false, true, true, false, true, false, false, true}, +} + +func drawVersionInfo(vi *versionInfo, set func(int, int, bool)) { + versionInfoBits, ok := versionInfoBitsByVersion[vi.Version] + + if ok && len(versionInfoBits) > 0 { + for i := 0; i < len(versionInfoBits); i++ { + x := (vi.modulWidth() - 11) + i%3 + y := i / 3 + set(x, y, versionInfoBits[len(versionInfoBits)-i-1]) + set(y, x, versionInfoBits[len(versionInfoBits)-i-1]) + } + } + +} + +func addPaddingAndTerminator(bl *utils.BitList, vi *versionInfo) { + for i := 0; i < 4 && bl.Len() < vi.totalDataBytes()*8; i++ { + bl.AddBit(false) + } + + for bl.Len()%8 != 0 { + bl.AddBit(false) + } + + for i := 0; bl.Len() < vi.totalDataBytes()*8; i++ { + if i%2 == 0 { + bl.AddByte(236) + } else { + bl.AddByte(17) + } + } +} diff --git a/vendor/github.com/boombuler/barcode/qr/errorcorrection.go b/vendor/github.com/boombuler/barcode/qr/errorcorrection.go new file mode 100644 index 000000000..08ebf0ce6 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/errorcorrection.go @@ -0,0 +1,29 @@ +package qr + +import ( + "github.com/boombuler/barcode/utils" +) + +type errorCorrection struct { + rs *utils.ReedSolomonEncoder +} + +var ec = newErrorCorrection() + +func newErrorCorrection() *errorCorrection { + fld := utils.NewGaloisField(285, 256, 0) + return &errorCorrection{utils.NewReedSolomonEncoder(fld)} +} + +func (ec *errorCorrection) calcECC(data []byte, eccCount byte) []byte { + dataInts := make([]int, len(data)) + for i := 0; i < len(data); i++ { + dataInts[i] = int(data[i]) + } + res := ec.rs.Encode(dataInts, int(eccCount)) + result := make([]byte, len(res)) + for i := 0; i < len(res); i++ { + result[i] = byte(res[i]) + } + return result +} diff --git a/vendor/github.com/boombuler/barcode/qr/numeric.go b/vendor/github.com/boombuler/barcode/qr/numeric.go new file mode 100644 index 000000000..49b44cc45 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/numeric.go @@ -0,0 +1,56 @@ +package qr + +import ( + "errors" + "fmt" + "strconv" + + "github.com/boombuler/barcode/utils" +) + +func encodeNumeric(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) { + contentBitCount := (len(content) / 3) * 10 + switch len(content) % 3 { + case 1: + contentBitCount += 4 + case 2: + contentBitCount += 7 + } + vi := findSmallestVersionInfo(ecl, numericMode, contentBitCount) + if vi == nil { + return nil, nil, errors.New("To much data to encode") + } + res := new(utils.BitList) + res.AddBits(int(numericMode), 4) + res.AddBits(len(content), vi.charCountBits(numericMode)) + + for pos := 0; pos < len(content); pos += 3 { + var curStr string + if pos+3 <= len(content) { + curStr = content[pos : pos+3] + } else { + curStr = content[pos:] + } + + i, err := strconv.Atoi(curStr) + if err != nil || i < 0 { + return nil, nil, fmt.Errorf("\"%s\" can not be encoded as %s", content, Numeric) + } + var bitCnt byte + switch len(curStr) % 3 { + case 0: + bitCnt = 10 + case 1: + bitCnt = 4 + break + case 2: + bitCnt = 7 + break + } + + res.AddBits(i, bitCnt) + } + + addPaddingAndTerminator(res, vi) + return res, vi, nil +} diff --git a/vendor/github.com/boombuler/barcode/qr/qrcode.go b/vendor/github.com/boombuler/barcode/qr/qrcode.go new file mode 100644 index 000000000..b7ac26d74 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/qrcode.go @@ -0,0 +1,166 @@ +package qr + +import ( + "image" + "image/color" + "math" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/utils" +) + +type qrcode struct { + dimension int + data *utils.BitList + content string +} + +func (qr *qrcode) Content() string { + return qr.content +} + +func (qr *qrcode) Metadata() barcode.Metadata { + return barcode.Metadata{"QR Code", 2} +} + +func (qr *qrcode) ColorModel() color.Model { + return color.Gray16Model +} + +func (qr *qrcode) Bounds() image.Rectangle { + return image.Rect(0, 0, qr.dimension, qr.dimension) +} + +func (qr *qrcode) At(x, y int) color.Color { + if qr.Get(x, y) { + return color.Black + } + return color.White +} + +func (qr *qrcode) Get(x, y int) bool { + return qr.data.GetBit(x*qr.dimension + y) +} + +func (qr *qrcode) Set(x, y int, val bool) { + qr.data.SetBit(x*qr.dimension+y, val) +} + +func (qr *qrcode) calcPenalty() uint { + return qr.calcPenaltyRule1() + qr.calcPenaltyRule2() + qr.calcPenaltyRule3() + qr.calcPenaltyRule4() +} + +func (qr *qrcode) calcPenaltyRule1() uint { + var result uint + for x := 0; x < qr.dimension; x++ { + checkForX := false + var cntX uint + checkForY := false + var cntY uint + + for y := 0; y < qr.dimension; y++ { + if qr.Get(x, y) == checkForX { + cntX++ + } else { + checkForX = !checkForX + if cntX >= 5 { + result += cntX - 2 + } + cntX = 1 + } + + if qr.Get(y, x) == checkForY { + cntY++ + } else { + checkForY = !checkForY + if cntY >= 5 { + result += cntY - 2 + } + cntY = 1 + } + } + + if cntX >= 5 { + result += cntX - 2 + } + if cntY >= 5 { + result += cntY - 2 + } + } + + return result +} + +func (qr *qrcode) calcPenaltyRule2() uint { + var result uint + for x := 0; x < qr.dimension-1; x++ { + for y := 0; y < qr.dimension-1; y++ { + check := qr.Get(x, y) + if qr.Get(x, y+1) == check && qr.Get(x+1, y) == check && qr.Get(x+1, y+1) == check { + result += 3 + } + } + } + return result +} + +func (qr *qrcode) calcPenaltyRule3() uint { + pattern1 := []bool{true, false, true, true, true, false, true, false, false, false, false} + pattern2 := []bool{false, false, false, false, true, false, true, true, true, false, true} + + var result uint + for x := 0; x <= qr.dimension-len(pattern1); x++ { + for y := 0; y < qr.dimension; y++ { + pattern1XFound := true + pattern2XFound := true + pattern1YFound := true + pattern2YFound := true + + for i := 0; i < len(pattern1); i++ { + iv := qr.Get(x+i, y) + if iv != pattern1[i] { + pattern1XFound = false + } + if iv != pattern2[i] { + pattern2XFound = false + } + iv = qr.Get(y, x+i) + if iv != pattern1[i] { + pattern1YFound = false + } + if iv != pattern2[i] { + pattern2YFound = false + } + } + if pattern1XFound || pattern2XFound { + result += 40 + } + if pattern1YFound || pattern2YFound { + result += 40 + } + } + } + + return result +} + +func (qr *qrcode) calcPenaltyRule4() uint { + totalNum := qr.data.Len() + trueCnt := 0 + for i := 0; i < totalNum; i++ { + if qr.data.GetBit(i) { + trueCnt++ + } + } + percDark := float64(trueCnt) * 100 / float64(totalNum) + floor := math.Abs(math.Floor(percDark/5) - 10) + ceil := math.Abs(math.Ceil(percDark/5) - 10) + return uint(math.Min(floor, ceil) * 10) +} + +func newBarcode(dim int) *qrcode { + res := new(qrcode) + res.dimension = dim + res.data = utils.NewBitList(dim * dim) + return res +} diff --git a/vendor/github.com/boombuler/barcode/qr/unicode.go b/vendor/github.com/boombuler/barcode/qr/unicode.go new file mode 100644 index 000000000..a9135ab6d --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/unicode.go @@ -0,0 +1,27 @@ +package qr + +import ( + "errors" + + "github.com/boombuler/barcode/utils" +) + +func encodeUnicode(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) { + data := []byte(content) + + vi := findSmallestVersionInfo(ecl, byteMode, len(data)*8) + if vi == nil { + return nil, nil, errors.New("To much data to encode") + } + + // It's not correct to add the unicode bytes to the result directly but most readers can't handle the + // required ECI header... + res := new(utils.BitList) + res.AddBits(int(byteMode), 4) + res.AddBits(len(content), vi.charCountBits(byteMode)) + for _, b := range data { + res.AddByte(b) + } + addPaddingAndTerminator(res, vi) + return res, vi, nil +} diff --git a/vendor/github.com/boombuler/barcode/qr/versioninfo.go b/vendor/github.com/boombuler/barcode/qr/versioninfo.go new file mode 100644 index 000000000..6852a5766 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/qr/versioninfo.go @@ -0,0 +1,310 @@ +package qr + +import "math" + +// ErrorCorrectionLevel indicates the amount of "backup data" stored in the QR code +type ErrorCorrectionLevel byte + +const ( + // L recovers 7% of data + L ErrorCorrectionLevel = iota + // M recovers 15% of data + M + // Q recovers 25% of data + Q + // H recovers 30% of data + H +) + +func (ecl ErrorCorrectionLevel) String() string { + switch ecl { + case L: + return "L" + case M: + return "M" + case Q: + return "Q" + case H: + return "H" + } + return "unknown" +} + +type encodingMode byte + +const ( + numericMode encodingMode = 1 + alphaNumericMode encodingMode = 2 + byteMode encodingMode = 4 + kanjiMode encodingMode = 8 +) + +type versionInfo struct { + Version byte + Level ErrorCorrectionLevel + ErrorCorrectionCodewordsPerBlock byte + NumberOfBlocksInGroup1 byte + DataCodeWordsPerBlockInGroup1 byte + NumberOfBlocksInGroup2 byte + DataCodeWordsPerBlockInGroup2 byte +} + +var versionInfos = []*versionInfo{ + &versionInfo{1, L, 7, 1, 19, 0, 0}, + &versionInfo{1, M, 10, 1, 16, 0, 0}, + &versionInfo{1, Q, 13, 1, 13, 0, 0}, + &versionInfo{1, H, 17, 1, 9, 0, 0}, + &versionInfo{2, L, 10, 1, 34, 0, 0}, + &versionInfo{2, M, 16, 1, 28, 0, 0}, + &versionInfo{2, Q, 22, 1, 22, 0, 0}, + &versionInfo{2, H, 28, 1, 16, 0, 0}, + &versionInfo{3, L, 15, 1, 55, 0, 0}, + &versionInfo{3, M, 26, 1, 44, 0, 0}, + &versionInfo{3, Q, 18, 2, 17, 0, 0}, + &versionInfo{3, H, 22, 2, 13, 0, 0}, + &versionInfo{4, L, 20, 1, 80, 0, 0}, + &versionInfo{4, M, 18, 2, 32, 0, 0}, + &versionInfo{4, Q, 26, 2, 24, 0, 0}, + &versionInfo{4, H, 16, 4, 9, 0, 0}, + &versionInfo{5, L, 26, 1, 108, 0, 0}, + &versionInfo{5, M, 24, 2, 43, 0, 0}, + &versionInfo{5, Q, 18, 2, 15, 2, 16}, + &versionInfo{5, H, 22, 2, 11, 2, 12}, + &versionInfo{6, L, 18, 2, 68, 0, 0}, + &versionInfo{6, M, 16, 4, 27, 0, 0}, + &versionInfo{6, Q, 24, 4, 19, 0, 0}, + &versionInfo{6, H, 28, 4, 15, 0, 0}, + &versionInfo{7, L, 20, 2, 78, 0, 0}, + &versionInfo{7, M, 18, 4, 31, 0, 0}, + &versionInfo{7, Q, 18, 2, 14, 4, 15}, + &versionInfo{7, H, 26, 4, 13, 1, 14}, + &versionInfo{8, L, 24, 2, 97, 0, 0}, + &versionInfo{8, M, 22, 2, 38, 2, 39}, + &versionInfo{8, Q, 22, 4, 18, 2, 19}, + &versionInfo{8, H, 26, 4, 14, 2, 15}, + &versionInfo{9, L, 30, 2, 116, 0, 0}, + &versionInfo{9, M, 22, 3, 36, 2, 37}, + &versionInfo{9, Q, 20, 4, 16, 4, 17}, + &versionInfo{9, H, 24, 4, 12, 4, 13}, + &versionInfo{10, L, 18, 2, 68, 2, 69}, + &versionInfo{10, M, 26, 4, 43, 1, 44}, + &versionInfo{10, Q, 24, 6, 19, 2, 20}, + &versionInfo{10, H, 28, 6, 15, 2, 16}, + &versionInfo{11, L, 20, 4, 81, 0, 0}, + &versionInfo{11, M, 30, 1, 50, 4, 51}, + &versionInfo{11, Q, 28, 4, 22, 4, 23}, + &versionInfo{11, H, 24, 3, 12, 8, 13}, + &versionInfo{12, L, 24, 2, 92, 2, 93}, + &versionInfo{12, M, 22, 6, 36, 2, 37}, + &versionInfo{12, Q, 26, 4, 20, 6, 21}, + &versionInfo{12, H, 28, 7, 14, 4, 15}, + &versionInfo{13, L, 26, 4, 107, 0, 0}, + &versionInfo{13, M, 22, 8, 37, 1, 38}, + &versionInfo{13, Q, 24, 8, 20, 4, 21}, + &versionInfo{13, H, 22, 12, 11, 4, 12}, + &versionInfo{14, L, 30, 3, 115, 1, 116}, + &versionInfo{14, M, 24, 4, 40, 5, 41}, + &versionInfo{14, Q, 20, 11, 16, 5, 17}, + &versionInfo{14, H, 24, 11, 12, 5, 13}, + &versionInfo{15, L, 22, 5, 87, 1, 88}, + &versionInfo{15, M, 24, 5, 41, 5, 42}, + &versionInfo{15, Q, 30, 5, 24, 7, 25}, + &versionInfo{15, H, 24, 11, 12, 7, 13}, + &versionInfo{16, L, 24, 5, 98, 1, 99}, + &versionInfo{16, M, 28, 7, 45, 3, 46}, + &versionInfo{16, Q, 24, 15, 19, 2, 20}, + &versionInfo{16, H, 30, 3, 15, 13, 16}, + &versionInfo{17, L, 28, 1, 107, 5, 108}, + &versionInfo{17, M, 28, 10, 46, 1, 47}, + &versionInfo{17, Q, 28, 1, 22, 15, 23}, + &versionInfo{17, H, 28, 2, 14, 17, 15}, + &versionInfo{18, L, 30, 5, 120, 1, 121}, + &versionInfo{18, M, 26, 9, 43, 4, 44}, + &versionInfo{18, Q, 28, 17, 22, 1, 23}, + &versionInfo{18, H, 28, 2, 14, 19, 15}, + &versionInfo{19, L, 28, 3, 113, 4, 114}, + &versionInfo{19, M, 26, 3, 44, 11, 45}, + &versionInfo{19, Q, 26, 17, 21, 4, 22}, + &versionInfo{19, H, 26, 9, 13, 16, 14}, + &versionInfo{20, L, 28, 3, 107, 5, 108}, + &versionInfo{20, M, 26, 3, 41, 13, 42}, + &versionInfo{20, Q, 30, 15, 24, 5, 25}, + &versionInfo{20, H, 28, 15, 15, 10, 16}, + &versionInfo{21, L, 28, 4, 116, 4, 117}, + &versionInfo{21, M, 26, 17, 42, 0, 0}, + &versionInfo{21, Q, 28, 17, 22, 6, 23}, + &versionInfo{21, H, 30, 19, 16, 6, 17}, + &versionInfo{22, L, 28, 2, 111, 7, 112}, + &versionInfo{22, M, 28, 17, 46, 0, 0}, + &versionInfo{22, Q, 30, 7, 24, 16, 25}, + &versionInfo{22, H, 24, 34, 13, 0, 0}, + &versionInfo{23, L, 30, 4, 121, 5, 122}, + &versionInfo{23, M, 28, 4, 47, 14, 48}, + &versionInfo{23, Q, 30, 11, 24, 14, 25}, + &versionInfo{23, H, 30, 16, 15, 14, 16}, + &versionInfo{24, L, 30, 6, 117, 4, 118}, + &versionInfo{24, M, 28, 6, 45, 14, 46}, + &versionInfo{24, Q, 30, 11, 24, 16, 25}, + &versionInfo{24, H, 30, 30, 16, 2, 17}, + &versionInfo{25, L, 26, 8, 106, 4, 107}, + &versionInfo{25, M, 28, 8, 47, 13, 48}, + &versionInfo{25, Q, 30, 7, 24, 22, 25}, + &versionInfo{25, H, 30, 22, 15, 13, 16}, + &versionInfo{26, L, 28, 10, 114, 2, 115}, + &versionInfo{26, M, 28, 19, 46, 4, 47}, + &versionInfo{26, Q, 28, 28, 22, 6, 23}, + &versionInfo{26, H, 30, 33, 16, 4, 17}, + &versionInfo{27, L, 30, 8, 122, 4, 123}, + &versionInfo{27, M, 28, 22, 45, 3, 46}, + &versionInfo{27, Q, 30, 8, 23, 26, 24}, + &versionInfo{27, H, 30, 12, 15, 28, 16}, + &versionInfo{28, L, 30, 3, 117, 10, 118}, + &versionInfo{28, M, 28, 3, 45, 23, 46}, + &versionInfo{28, Q, 30, 4, 24, 31, 25}, + &versionInfo{28, H, 30, 11, 15, 31, 16}, + &versionInfo{29, L, 30, 7, 116, 7, 117}, + &versionInfo{29, M, 28, 21, 45, 7, 46}, + &versionInfo{29, Q, 30, 1, 23, 37, 24}, + &versionInfo{29, H, 30, 19, 15, 26, 16}, + &versionInfo{30, L, 30, 5, 115, 10, 116}, + &versionInfo{30, M, 28, 19, 47, 10, 48}, + &versionInfo{30, Q, 30, 15, 24, 25, 25}, + &versionInfo{30, H, 30, 23, 15, 25, 16}, + &versionInfo{31, L, 30, 13, 115, 3, 116}, + &versionInfo{31, M, 28, 2, 46, 29, 47}, + &versionInfo{31, Q, 30, 42, 24, 1, 25}, + &versionInfo{31, H, 30, 23, 15, 28, 16}, + &versionInfo{32, L, 30, 17, 115, 0, 0}, + &versionInfo{32, M, 28, 10, 46, 23, 47}, + &versionInfo{32, Q, 30, 10, 24, 35, 25}, + &versionInfo{32, H, 30, 19, 15, 35, 16}, + &versionInfo{33, L, 30, 17, 115, 1, 116}, + &versionInfo{33, M, 28, 14, 46, 21, 47}, + &versionInfo{33, Q, 30, 29, 24, 19, 25}, + &versionInfo{33, H, 30, 11, 15, 46, 16}, + &versionInfo{34, L, 30, 13, 115, 6, 116}, + &versionInfo{34, M, 28, 14, 46, 23, 47}, + &versionInfo{34, Q, 30, 44, 24, 7, 25}, + &versionInfo{34, H, 30, 59, 16, 1, 17}, + &versionInfo{35, L, 30, 12, 121, 7, 122}, + &versionInfo{35, M, 28, 12, 47, 26, 48}, + &versionInfo{35, Q, 30, 39, 24, 14, 25}, + &versionInfo{35, H, 30, 22, 15, 41, 16}, + &versionInfo{36, L, 30, 6, 121, 14, 122}, + &versionInfo{36, M, 28, 6, 47, 34, 48}, + &versionInfo{36, Q, 30, 46, 24, 10, 25}, + &versionInfo{36, H, 30, 2, 15, 64, 16}, + &versionInfo{37, L, 30, 17, 122, 4, 123}, + &versionInfo{37, M, 28, 29, 46, 14, 47}, + &versionInfo{37, Q, 30, 49, 24, 10, 25}, + &versionInfo{37, H, 30, 24, 15, 46, 16}, + &versionInfo{38, L, 30, 4, 122, 18, 123}, + &versionInfo{38, M, 28, 13, 46, 32, 47}, + &versionInfo{38, Q, 30, 48, 24, 14, 25}, + &versionInfo{38, H, 30, 42, 15, 32, 16}, + &versionInfo{39, L, 30, 20, 117, 4, 118}, + &versionInfo{39, M, 28, 40, 47, 7, 48}, + &versionInfo{39, Q, 30, 43, 24, 22, 25}, + &versionInfo{39, H, 30, 10, 15, 67, 16}, + &versionInfo{40, L, 30, 19, 118, 6, 119}, + &versionInfo{40, M, 28, 18, 47, 31, 48}, + &versionInfo{40, Q, 30, 34, 24, 34, 25}, + &versionInfo{40, H, 30, 20, 15, 61, 16}, +} + +func (vi *versionInfo) totalDataBytes() int { + g1Data := int(vi.NumberOfBlocksInGroup1) * int(vi.DataCodeWordsPerBlockInGroup1) + g2Data := int(vi.NumberOfBlocksInGroup2) * int(vi.DataCodeWordsPerBlockInGroup2) + return (g1Data + g2Data) +} + +func (vi *versionInfo) charCountBits(m encodingMode) byte { + switch m { + case numericMode: + if vi.Version < 10 { + return 10 + } else if vi.Version < 27 { + return 12 + } + return 14 + + case alphaNumericMode: + if vi.Version < 10 { + return 9 + } else if vi.Version < 27 { + return 11 + } + return 13 + + case byteMode: + if vi.Version < 10 { + return 8 + } + return 16 + + case kanjiMode: + if vi.Version < 10 { + return 8 + } else if vi.Version < 27 { + return 10 + } + return 12 + default: + return 0 + } +} + +func (vi *versionInfo) modulWidth() int { + return ((int(vi.Version) - 1) * 4) + 21 +} + +func (vi *versionInfo) alignmentPatternPlacements() []int { + if vi.Version == 1 { + return make([]int, 0) + } + + first := 6 + last := vi.modulWidth() - 7 + space := float64(last - first) + count := int(math.Ceil(space/28)) + 1 + + result := make([]int, count) + result[0] = first + result[len(result)-1] = last + if count > 2 { + step := int(math.Ceil(float64(last-first) / float64(count-1))) + if step%2 == 1 { + frac := float64(last-first) / float64(count-1) + _, x := math.Modf(frac) + if x >= 0.5 { + frac = math.Ceil(frac) + } else { + frac = math.Floor(frac) + } + + if int(frac)%2 == 0 { + step-- + } else { + step++ + } + } + + for i := 1; i <= count-2; i++ { + result[i] = last - (step * (count - 1 - i)) + } + } + + return result +} + +func findSmallestVersionInfo(ecl ErrorCorrectionLevel, mode encodingMode, dataBits int) *versionInfo { + dataBits = dataBits + 4 // mode indicator + for _, vi := range versionInfos { + if vi.Level == ecl { + if (vi.totalDataBytes() * 8) >= (dataBits + int(vi.charCountBits(mode))) { + return vi + } + } + } + return nil +} diff --git a/vendor/github.com/boombuler/barcode/scaledbarcode.go b/vendor/github.com/boombuler/barcode/scaledbarcode.go new file mode 100644 index 000000000..152b18017 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/scaledbarcode.go @@ -0,0 +1,134 @@ +package barcode + +import ( + "errors" + "fmt" + "image" + "image/color" + "math" +) + +type wrapFunc func(x, y int) color.Color + +type scaledBarcode struct { + wrapped Barcode + wrapperFunc wrapFunc + rect image.Rectangle +} + +type intCSscaledBC struct { + scaledBarcode +} + +func (bc *scaledBarcode) Content() string { + return bc.wrapped.Content() +} + +func (bc *scaledBarcode) Metadata() Metadata { + return bc.wrapped.Metadata() +} + +func (bc *scaledBarcode) ColorModel() color.Model { + return bc.wrapped.ColorModel() +} + +func (bc *scaledBarcode) Bounds() image.Rectangle { + return bc.rect +} + +func (bc *scaledBarcode) At(x, y int) color.Color { + return bc.wrapperFunc(x, y) +} + +func (bc *intCSscaledBC) CheckSum() int { + if cs, ok := bc.wrapped.(BarcodeIntCS); ok { + return cs.CheckSum() + } + return 0 +} + +// Scale returns a resized barcode with the given width and height. +func Scale(bc Barcode, width, height int) (Barcode, error) { + switch bc.Metadata().Dimensions { + case 1: + return scale1DCode(bc, width, height) + case 2: + return scale2DCode(bc, width, height) + } + + return nil, errors.New("unsupported barcode format") +} + +func newScaledBC(wrapped Barcode, wrapperFunc wrapFunc, rect image.Rectangle) Barcode { + result := &scaledBarcode{ + wrapped: wrapped, + wrapperFunc: wrapperFunc, + rect: rect, + } + + if _, ok := wrapped.(BarcodeIntCS); ok { + return &intCSscaledBC{*result} + } + return result +} + +func scale2DCode(bc Barcode, width, height int) (Barcode, error) { + orgBounds := bc.Bounds() + orgWidth := orgBounds.Max.X - orgBounds.Min.X + orgHeight := orgBounds.Max.Y - orgBounds.Min.Y + + factor := int(math.Min(float64(width)/float64(orgWidth), float64(height)/float64(orgHeight))) + if factor <= 0 { + return nil, fmt.Errorf("can not scale barcode to an image smaller than %dx%d", orgWidth, orgHeight) + } + + offsetX := (width - (orgWidth * factor)) / 2 + offsetY := (height - (orgHeight * factor)) / 2 + + wrap := func(x, y int) color.Color { + if x < offsetX || y < offsetY { + return color.White + } + x = (x - offsetX) / factor + y = (y - offsetY) / factor + if x >= orgWidth || y >= orgHeight { + return color.White + } + return bc.At(x, y) + } + + return newScaledBC( + bc, + wrap, + image.Rect(0, 0, width, height), + ), nil +} + +func scale1DCode(bc Barcode, width, height int) (Barcode, error) { + orgBounds := bc.Bounds() + orgWidth := orgBounds.Max.X - orgBounds.Min.X + factor := int(float64(width) / float64(orgWidth)) + + if factor <= 0 { + return nil, fmt.Errorf("can not scale barcode to an image smaller than %dx1", orgWidth) + } + offsetX := (width - (orgWidth * factor)) / 2 + + wrap := func(x, y int) color.Color { + if x < offsetX { + return color.White + } + x = (x - offsetX) / factor + + if x >= orgWidth { + return color.White + } + return bc.At(x, 0) + } + + return newScaledBC( + bc, + wrap, + image.Rect(0, 0, width, height), + ), nil +} diff --git a/vendor/github.com/boombuler/barcode/utils/base1dcode.go b/vendor/github.com/boombuler/barcode/utils/base1dcode.go new file mode 100644 index 000000000..75e50048c --- /dev/null +++ b/vendor/github.com/boombuler/barcode/utils/base1dcode.go @@ -0,0 +1,57 @@ +// Package utils contain some utilities which are needed to create barcodes +package utils + +import ( + "image" + "image/color" + + "github.com/boombuler/barcode" +) + +type base1DCode struct { + *BitList + kind string + content string +} + +type base1DCodeIntCS struct { + base1DCode + checksum int +} + +func (c *base1DCode) Content() string { + return c.content +} + +func (c *base1DCode) Metadata() barcode.Metadata { + return barcode.Metadata{c.kind, 1} +} + +func (c *base1DCode) ColorModel() color.Model { + return color.Gray16Model +} + +func (c *base1DCode) Bounds() image.Rectangle { + return image.Rect(0, 0, c.Len(), 1) +} + +func (c *base1DCode) At(x, y int) color.Color { + if c.GetBit(x) { + return color.Black + } + return color.White +} + +func (c *base1DCodeIntCS) CheckSum() int { + return c.checksum +} + +// New1DCode creates a new 1D barcode where the bars are represented by the bits in the bars BitList +func New1DCodeIntCheckSum(codeKind, content string, bars *BitList, checksum int) barcode.BarcodeIntCS { + return &base1DCodeIntCS{base1DCode{bars, codeKind, content}, checksum} +} + +// New1DCode creates a new 1D barcode where the bars are represented by the bits in the bars BitList +func New1DCode(codeKind, content string, bars *BitList) barcode.Barcode { + return &base1DCode{bars, codeKind, content} +} diff --git a/vendor/github.com/boombuler/barcode/utils/bitlist.go b/vendor/github.com/boombuler/barcode/utils/bitlist.go new file mode 100644 index 000000000..bb05e53b5 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/utils/bitlist.go @@ -0,0 +1,119 @@ +package utils + +// BitList is a list that contains bits +type BitList struct { + count int + data []int32 +} + +// NewBitList returns a new BitList with the given length +// all bits are initialize with false +func NewBitList(capacity int) *BitList { + bl := new(BitList) + bl.count = capacity + x := 0 + if capacity%32 != 0 { + x = 1 + } + bl.data = make([]int32, capacity/32+x) + return bl +} + +// Len returns the number of contained bits +func (bl *BitList) Len() int { + return bl.count +} + +func (bl *BitList) grow() { + growBy := len(bl.data) + if growBy < 128 { + growBy = 128 + } else if growBy >= 1024 { + growBy = 1024 + } + + nd := make([]int32, len(bl.data)+growBy) + copy(nd, bl.data) + bl.data = nd +} + +// AddBit appends the given bits to the end of the list +func (bl *BitList) AddBit(bits ...bool) { + for _, bit := range bits { + itmIndex := bl.count / 32 + for itmIndex >= len(bl.data) { + bl.grow() + } + bl.SetBit(bl.count, bit) + bl.count++ + } +} + +// SetBit sets the bit at the given index to the given value +func (bl *BitList) SetBit(index int, value bool) { + itmIndex := index / 32 + itmBitShift := 31 - (index % 32) + if value { + bl.data[itmIndex] = bl.data[itmIndex] | 1<> uint(itmBitShift)) & 1) == 1 +} + +// AddByte appends all 8 bits of the given byte to the end of the list +func (bl *BitList) AddByte(b byte) { + for i := 7; i >= 0; i-- { + bl.AddBit(((b >> uint(i)) & 1) == 1) + } +} + +// AddBits appends the last (LSB) 'count' bits of 'b' the the end of the list +func (bl *BitList) AddBits(b int, count byte) { + for i := int(count) - 1; i >= 0; i-- { + bl.AddBit(((b >> uint(i)) & 1) == 1) + } +} + +// GetBytes returns all bits of the BitList as a []byte +func (bl *BitList) GetBytes() []byte { + len := bl.count >> 3 + if (bl.count % 8) != 0 { + len++ + } + result := make([]byte, len) + for i := 0; i < len; i++ { + shift := (3 - (i % 4)) * 8 + result[i] = (byte)((bl.data[i/4] >> uint(shift)) & 0xFF) + } + return result +} + +// IterateBytes iterates through all bytes contained in the BitList +func (bl *BitList) IterateBytes() <-chan byte { + res := make(chan byte) + + go func() { + c := bl.count + shift := 24 + i := 0 + for c > 0 { + res <- byte((bl.data[i] >> uint(shift)) & 0xFF) + shift -= 8 + if shift < 0 { + shift = 24 + i++ + } + c -= 8 + } + close(res) + }() + + return res +} diff --git a/vendor/github.com/boombuler/barcode/utils/galoisfield.go b/vendor/github.com/boombuler/barcode/utils/galoisfield.go new file mode 100644 index 000000000..68726fbfd --- /dev/null +++ b/vendor/github.com/boombuler/barcode/utils/galoisfield.go @@ -0,0 +1,65 @@ +package utils + +// GaloisField encapsulates galois field arithmetics +type GaloisField struct { + Size int + Base int + ALogTbl []int + LogTbl []int +} + +// NewGaloisField creates a new galois field +func NewGaloisField(pp, fieldSize, b int) *GaloisField { + result := new(GaloisField) + + result.Size = fieldSize + result.Base = b + result.ALogTbl = make([]int, fieldSize) + result.LogTbl = make([]int, fieldSize) + + x := 1 + for i := 0; i < fieldSize; i++ { + result.ALogTbl[i] = x + x = x * 2 + if x >= fieldSize { + x = (x ^ pp) & (fieldSize - 1) + } + } + + for i := 0; i < fieldSize; i++ { + result.LogTbl[result.ALogTbl[i]] = int(i) + } + + return result +} + +func (gf *GaloisField) Zero() *GFPoly { + return NewGFPoly(gf, []int{0}) +} + +// AddOrSub add or substract two numbers +func (gf *GaloisField) AddOrSub(a, b int) int { + return a ^ b +} + +// Multiply multiplys two numbers +func (gf *GaloisField) Multiply(a, b int) int { + if a == 0 || b == 0 { + return 0 + } + return gf.ALogTbl[(gf.LogTbl[a]+gf.LogTbl[b])%(gf.Size-1)] +} + +// Divide divides two numbers +func (gf *GaloisField) Divide(a, b int) int { + if b == 0 { + panic("divide by zero") + } else if a == 0 { + return 0 + } + return gf.ALogTbl[(gf.LogTbl[a]-gf.LogTbl[b])%(gf.Size-1)] +} + +func (gf *GaloisField) Invers(num int) int { + return gf.ALogTbl[(gf.Size-1)-gf.LogTbl[num]] +} diff --git a/vendor/github.com/boombuler/barcode/utils/gfpoly.go b/vendor/github.com/boombuler/barcode/utils/gfpoly.go new file mode 100644 index 000000000..c56bb40b9 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/utils/gfpoly.go @@ -0,0 +1,103 @@ +package utils + +type GFPoly struct { + gf *GaloisField + Coefficients []int +} + +func (gp *GFPoly) Degree() int { + return len(gp.Coefficients) - 1 +} + +func (gp *GFPoly) Zero() bool { + return gp.Coefficients[0] == 0 +} + +// GetCoefficient returns the coefficient of x ^ degree +func (gp *GFPoly) GetCoefficient(degree int) int { + return gp.Coefficients[gp.Degree()-degree] +} + +func (gp *GFPoly) AddOrSubstract(other *GFPoly) *GFPoly { + if gp.Zero() { + return other + } else if other.Zero() { + return gp + } + smallCoeff := gp.Coefficients + largeCoeff := other.Coefficients + if len(smallCoeff) > len(largeCoeff) { + largeCoeff, smallCoeff = smallCoeff, largeCoeff + } + sumDiff := make([]int, len(largeCoeff)) + lenDiff := len(largeCoeff) - len(smallCoeff) + copy(sumDiff, largeCoeff[:lenDiff]) + for i := lenDiff; i < len(largeCoeff); i++ { + sumDiff[i] = int(gp.gf.AddOrSub(int(smallCoeff[i-lenDiff]), int(largeCoeff[i]))) + } + return NewGFPoly(gp.gf, sumDiff) +} + +func (gp *GFPoly) MultByMonominal(degree int, coeff int) *GFPoly { + if coeff == 0 { + return gp.gf.Zero() + } + size := len(gp.Coefficients) + result := make([]int, size+degree) + for i := 0; i < size; i++ { + result[i] = int(gp.gf.Multiply(int(gp.Coefficients[i]), int(coeff))) + } + return NewGFPoly(gp.gf, result) +} + +func (gp *GFPoly) Multiply(other *GFPoly) *GFPoly { + if gp.Zero() || other.Zero() { + return gp.gf.Zero() + } + aCoeff := gp.Coefficients + aLen := len(aCoeff) + bCoeff := other.Coefficients + bLen := len(bCoeff) + product := make([]int, aLen+bLen-1) + for i := 0; i < aLen; i++ { + ac := int(aCoeff[i]) + for j := 0; j < bLen; j++ { + bc := int(bCoeff[j]) + product[i+j] = int(gp.gf.AddOrSub(int(product[i+j]), gp.gf.Multiply(ac, bc))) + } + } + return NewGFPoly(gp.gf, product) +} + +func (gp *GFPoly) Divide(other *GFPoly) (quotient *GFPoly, remainder *GFPoly) { + quotient = gp.gf.Zero() + remainder = gp + fld := gp.gf + denomLeadTerm := other.GetCoefficient(other.Degree()) + inversDenomLeadTerm := fld.Invers(int(denomLeadTerm)) + for remainder.Degree() >= other.Degree() && !remainder.Zero() { + degreeDiff := remainder.Degree() - other.Degree() + scale := int(fld.Multiply(int(remainder.GetCoefficient(remainder.Degree())), inversDenomLeadTerm)) + term := other.MultByMonominal(degreeDiff, scale) + itQuot := NewMonominalPoly(fld, degreeDiff, scale) + quotient = quotient.AddOrSubstract(itQuot) + remainder = remainder.AddOrSubstract(term) + } + return +} + +func NewMonominalPoly(field *GaloisField, degree int, coeff int) *GFPoly { + if coeff == 0 { + return field.Zero() + } + result := make([]int, degree+1) + result[0] = coeff + return NewGFPoly(field, result) +} + +func NewGFPoly(field *GaloisField, coefficients []int) *GFPoly { + for len(coefficients) > 1 && coefficients[0] == 0 { + coefficients = coefficients[1:] + } + return &GFPoly{field, coefficients} +} diff --git a/vendor/github.com/boombuler/barcode/utils/reedsolomon.go b/vendor/github.com/boombuler/barcode/utils/reedsolomon.go new file mode 100644 index 000000000..53af91ad4 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/utils/reedsolomon.go @@ -0,0 +1,44 @@ +package utils + +import ( + "sync" +) + +type ReedSolomonEncoder struct { + gf *GaloisField + polynomes []*GFPoly + m *sync.Mutex +} + +func NewReedSolomonEncoder(gf *GaloisField) *ReedSolomonEncoder { + return &ReedSolomonEncoder{ + gf, []*GFPoly{NewGFPoly(gf, []int{1})}, new(sync.Mutex), + } +} + +func (rs *ReedSolomonEncoder) getPolynomial(degree int) *GFPoly { + rs.m.Lock() + defer rs.m.Unlock() + + if degree >= len(rs.polynomes) { + last := rs.polynomes[len(rs.polynomes)-1] + for d := len(rs.polynomes); d <= degree; d++ { + next := last.Multiply(NewGFPoly(rs.gf, []int{1, rs.gf.ALogTbl[d-1+rs.gf.Base]})) + rs.polynomes = append(rs.polynomes, next) + last = next + } + } + return rs.polynomes[degree] +} + +func (rs *ReedSolomonEncoder) Encode(data []int, eccCount int) []int { + generator := rs.getPolynomial(eccCount) + info := NewGFPoly(rs.gf, data) + info = info.MultByMonominal(eccCount, 1) + _, remainder := info.Divide(generator) + + result := make([]int, eccCount) + numZero := int(eccCount) - len(remainder.Coefficients) + copy(result[numZero:], remainder.Coefficients) + return result +} diff --git a/vendor/github.com/boombuler/barcode/utils/runeint.go b/vendor/github.com/boombuler/barcode/utils/runeint.go new file mode 100644 index 000000000..d2e5e61e5 --- /dev/null +++ b/vendor/github.com/boombuler/barcode/utils/runeint.go @@ -0,0 +1,19 @@ +package utils + +// RuneToInt converts a rune between '0' and '9' to an integer between 0 and 9 +// If the rune is outside of this range -1 is returned. +func RuneToInt(r rune) int { + if r >= '0' && r <= '9' { + return int(r - '0') + } + return -1 +} + +// IntToRune converts a digit 0 - 9 to the rune '0' - '9'. If the given int is outside +// of this range 'F' is returned! +func IntToRune(i int) rune { + if i >= 0 && i <= 9 { + return rune(i + '0') + } + return 'F' +} diff --git a/vendor/github.com/pquerna/otp/LICENSE b/vendor/github.com/pquerna/otp/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/vendor/github.com/pquerna/otp/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/pquerna/otp/NOTICE b/vendor/github.com/pquerna/otp/NOTICE new file mode 100644 index 000000000..50e2e7501 --- /dev/null +++ b/vendor/github.com/pquerna/otp/NOTICE @@ -0,0 +1,5 @@ +otp +Copyright (c) 2014, Paul Querna + +This product includes software developed by +Paul Querna (http://paul.querna.org/). diff --git a/vendor/github.com/pquerna/otp/README.md b/vendor/github.com/pquerna/otp/README.md new file mode 100644 index 000000000..148e8980d --- /dev/null +++ b/vendor/github.com/pquerna/otp/README.md @@ -0,0 +1,60 @@ +# otp: One Time Password utilities Go / Golang + +[![GoDoc](https://godoc.org/github.com/pquerna/otp?status.svg)](https://godoc.org/github.com/pquerna/otp) [![Build Status](https://travis-ci.org/pquerna/otp.svg?branch=master)](https://travis-ci.org/pquerna/otp) + +# Why One Time Passwords? + +One Time Passwords (OTPs) are an mechanism to improve security over passwords alone. When a Time-based OTP (TOTP) is stored on a user's phone, and combined with something the user knows (Password), you have an easy on-ramp to [Multi-factor authentication](http://en.wikipedia.org/wiki/Multi-factor_authentication) without adding a dependency on a SMS provider. This Password and TOTP combination is used by many popular websites including Google, Github, Facebook, Salesforce and many others. + +The `otp` library enables you to easily add TOTPs to your own application, increasing your user's security against mass-password breaches and malware. + +Because TOTP is standardized and widely deployed, there are many [mobile clients and software implementations](http://en.wikipedia.org/wiki/Time-based_One-time_Password_Algorithm#Client_implementations). + +## `otp` Supports: + +* Generating QR Code images for easy user enrollment. +* Time-based One-time Password Algorithm (TOTP) (RFC 6238): Time based OTP, the most commonly used method. +* HMAC-based One-time Password Algorithm (HOTP) (RFC 4226): Counter based OTP, which TOTP is based upon. +* Generation and Validation of codes for either algorithm. + +## Implementing TOTP in your application: + +### User Enrollment + +For an example of a working enrollment work flow, [Github has documented theirs](https://help.github.com/articles/configuring-two-factor-authentication-via-a-totp-mobile-app/ +), but the basics are: + +1. Generate new TOTP Key for a User. `key,_ := totp.Generate(...)`. +1. Display the Key's Secret and QR-Code for the User. `key.Secret()` and `key.Image(...)`. +1. Test that the user can successfully use their TOTP. `totp.Validate(...)`. +1. Store TOTP Secret for the User in your backend. `key.Secret()` +1. Provide the user with "recovery codes". (See Recovery Codes bellow) + +### Code Generation + +* In either TOTP or HOTP cases, use the `GenerateCode` function and a counter or + `time.Time` struct to generate a valid code compatible with most implementations. +* For uncommon or custom settings, or to catch unlikely errors, use `GenerateCodeCustom` + in either module. + +### Validation + +1. Prompt and validate User's password as normal. +1. If the user has TOTP enabled, prompt for TOTP passcode. +1. Retrieve the User's TOTP Secret from your backend. +1. Validate the user's passcode. `totp.Validate(...)` + + +### Recovery Codes + +When a user loses access to their TOTP device, they would no longer have access to their account. Because TOTPs are often configured on mobile devices that can be lost, stolen or damaged, this is a common problem. For this reason many providers give their users "backup codes" or "recovery codes". These are a set of one time use codes that can be used instead of the TOTP. These can simply be randomly generated strings that you store in your backend. [Github's documentation provides an overview of the user experience]( +https://help.github.com/articles/downloading-your-two-factor-authentication-recovery-codes/). + + +## Improvements, bugs, adding feature, etc: + +Please [open issues in Github](https://github.com/pquerna/otp/issues) for ideas, bugs, and general thoughts. Pull requests are of course preferred :) + +## License + +`otp` is licensed under the [Apache License, Version 2.0](./LICENSE) diff --git a/vendor/github.com/pquerna/otp/doc.go b/vendor/github.com/pquerna/otp/doc.go new file mode 100644 index 000000000..b8b4c8cc1 --- /dev/null +++ b/vendor/github.com/pquerna/otp/doc.go @@ -0,0 +1,70 @@ +/** + * Copyright 2014 Paul Querna + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package otp implements both HOTP and TOTP based +// one time passcodes in a Google Authenticator compatible manner. +// +// When adding a TOTP for a user, you must store the "secret" value +// persistently. It is recommend to store the secret in an encrypted field in your +// datastore. Due to how TOTP works, it is not possible to store a hash +// for the secret value like you would a password. +// +// To enroll a user, you must first generate an OTP for them. Google +// Authenticator supports using a QR code as an enrollment method: +// +// import ( +// "github.com/pquerna/otp/totp" +// +// "bytes" +// "image/png" +// ) +// +// key, err := totp.Generate(totp.GenerateOpts{ +// Issuer: "Example.com", +// AccountName: "alice@example.com", +// }) +// +// // Convert TOTP key into a QR code encoded as a PNG image. +// var buf bytes.Buffer +// img, err := key.Image(200, 200) +// png.Encode(&buf, img) +// +// // display the QR code to the user. +// display(buf.Bytes()) +// +// // Now Validate that the user's successfully added the passcode. +// passcode := promptForPasscode() +// valid := totp.Validate(passcode, key.Secret()) +// +// if valid { +// // User successfully used their TOTP, save it to your backend! +// storeSecret("alice@example.com", key.Secret()) +// } +// +// Validating a TOTP passcode is very easy, just prompt the user for a passcode +// and retrieve the associated user's previously stored secret. +// import "github.com/pquerna/otp/totp" +// +// passcode := promptForPasscode() +// secret := getSecret("alice@example.com") +// +// valid := totp.Validate(passcode, secret) +// +// if valid { +// // Success! continue login process. +// } +package otp diff --git a/vendor/github.com/pquerna/otp/example/main.go b/vendor/github.com/pquerna/otp/example/main.go new file mode 100644 index 000000000..77a3e3beb --- /dev/null +++ b/vendor/github.com/pquerna/otp/example/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + + "bufio" + "bytes" + "fmt" + "image/png" + "io/ioutil" + "os" +) + +func display(key *otp.Key, data []byte) { + fmt.Printf("Issuer: %s\n", key.Issuer()) + fmt.Printf("Account Name: %s\n", key.AccountName()) + fmt.Printf("Secret: %s\n", key.Secret()) + fmt.Println("Writing PNG to qr-code.png....") + ioutil.WriteFile("qr-code.png", data, 0644) + fmt.Println("") + fmt.Println("Please add your TOTP to your OTP Application now!") + fmt.Println("") +} + +func promptForPasscode() string { + reader := bufio.NewReader(os.Stdin) + fmt.Print("Enter Passcode: ") + text, _ := reader.ReadString('\n') + return text +} + +func main() { + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: "Example.com", + AccountName: "alice@example.com", + }) + if err != nil { + panic(err) + } + // Convert TOTP key into a PNG + var buf bytes.Buffer + img, err := key.Image(200, 200) + if err != nil { + panic(err) + } + png.Encode(&buf, img) + + // display the QR code to the user. + display(key, buf.Bytes()) + + // Now Validate that the user's successfully added the passcode. + fmt.Println("Validating TOTP...") + passcode := promptForPasscode() + valid := totp.Validate(passcode, key.Secret()) + if valid { + println("Valid passcode!") + os.Exit(0) + } else { + println("Invalid passocde!") + os.Exit(1) + } +} diff --git a/vendor/github.com/pquerna/otp/hotp/hotp.go b/vendor/github.com/pquerna/otp/hotp/hotp.go new file mode 100644 index 000000000..ced7d8e28 --- /dev/null +++ b/vendor/github.com/pquerna/otp/hotp/hotp.go @@ -0,0 +1,187 @@ +/** + * Copyright 2014 Paul Querna + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package hotp + +import ( + "github.com/pquerna/otp" + + "crypto/hmac" + "crypto/rand" + "crypto/subtle" + "encoding/base32" + "encoding/binary" + "fmt" + "math" + "net/url" + "strings" +) + +const debug = false + +// Validate a HOTP passcode given a counter and secret. +// This is a shortcut for ValidateCustom, with parameters that +// are compataible with Google-Authenticator. +func Validate(passcode string, counter uint64, secret string) bool { + rv, _ := ValidateCustom( + passcode, + counter, + secret, + ValidateOpts{ + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }, + ) + return rv +} + +// ValidateOpts provides options for ValidateCustom(). +type ValidateOpts struct { + // Digits as part of the input. Defaults to 6. + Digits otp.Digits + // Algorithm to use for HMAC. Defaults to SHA1. + Algorithm otp.Algorithm +} + +// GenerateCode creates a HOTP passcode given a counter and secret. +// This is a shortcut for GenerateCodeCustom, with parameters that +// are compataible with Google-Authenticator. +func GenerateCode(secret string, counter uint64) (string, error) { + return GenerateCodeCustom(secret, counter, ValidateOpts{ + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) +} + +// GenerateCodeCustom uses a counter and secret value and options struct to +// create a passcode. +func GenerateCodeCustom(secret string, counter uint64, opts ValidateOpts) (passcode string, err error) { + // As noted in issue #10 this adds support for TOTP secrets that are + // missing their padding. + if n := len(secret) % 8; n != 0 { + secret = secret + strings.Repeat("=", 8-n) + } + + secretBytes, err := base32.StdEncoding.DecodeString(secret) + if err != nil { + return "", otp.ErrValidateSecretInvalidBase32 + } + + buf := make([]byte, 8) + mac := hmac.New(opts.Algorithm.Hash, secretBytes) + binary.BigEndian.PutUint64(buf, counter) + if debug { + fmt.Printf("counter=%v\n", counter) + fmt.Printf("buf=%v\n", buf) + } + + mac.Write(buf) + sum := mac.Sum(nil) + + // "Dynamic truncation" in RFC 4226 + // http://tools.ietf.org/html/rfc4226#section-5.4 + offset := sum[len(sum)-1] & 0xf + value := int64(((int(sum[offset]) & 0x7f) << 24) | + ((int(sum[offset+1] & 0xff)) << 16) | + ((int(sum[offset+2] & 0xff)) << 8) | + (int(sum[offset+3]) & 0xff)) + + l := opts.Digits.Length() + mod := int32(value % int64(math.Pow10(l))) + + if debug { + fmt.Printf("offset=%v\n", offset) + fmt.Printf("value=%v\n", value) + fmt.Printf("mod'ed=%v\n", mod) + } + + return opts.Digits.Format(mod), nil +} + +// ValidateCustom validates an HOTP with customizable options. Most users should +// use Validate(). +func ValidateCustom(passcode string, counter uint64, secret string, opts ValidateOpts) (bool, error) { + passcode = strings.TrimSpace(passcode) + + if len(passcode) != opts.Digits.Length() { + return false, otp.ErrValidateInputInvalidLength + } + + otpstr, err := GenerateCodeCustom(secret, counter, opts) + if err != nil { + return false, err + } + + if subtle.ConstantTimeCompare([]byte(otpstr), []byte(passcode)) == 1 { + return true, nil + } + + return false, nil +} + +// GenerateOpts provides options for .Generate() +type GenerateOpts struct { + // Name of the issuing Organization/Company. + Issuer string + // Name of the User's Account (eg, email address) + AccountName string + // Size in size of the generated Secret. Defaults to 10 bytes. + SecretSize uint + // Digits to request. Defaults to 6. + Digits otp.Digits + // Algorithm to use for HMAC. Defaults to SHA1. + Algorithm otp.Algorithm +} + +// Generate creates a new HOTP Key. +func Generate(opts GenerateOpts) (*otp.Key, error) { + // url encode the Issuer/AccountName + if opts.Issuer == "" { + return nil, otp.ErrGenerateMissingIssuer + } + + if opts.AccountName == "" { + return nil, otp.ErrGenerateMissingAccountName + } + + if opts.SecretSize == 0 { + opts.SecretSize = 10 + } + + // otpauth://totp/Example:alice@google.com?secret=JBSWY3DPEHPK3PXP&issuer=Example + + v := url.Values{} + secret := make([]byte, opts.SecretSize) + _, err := rand.Read(secret) + if err != nil { + return nil, err + } + + v.Set("secret", base32.StdEncoding.EncodeToString(secret)) + v.Set("issuer", opts.Issuer) + v.Set("algorithm", opts.Algorithm.String()) + v.Set("digits", opts.Digits.String()) + + u := url.URL{ + Scheme: "otpauth", + Host: "hotp", + Path: "/" + opts.Issuer + ":" + opts.AccountName, + RawQuery: v.Encode(), + } + + return otp.NewKeyFromURL(u.String()) +} diff --git a/vendor/github.com/pquerna/otp/otp.go b/vendor/github.com/pquerna/otp/otp.go new file mode 100644 index 000000000..0fa970927 --- /dev/null +++ b/vendor/github.com/pquerna/otp/otp.go @@ -0,0 +1,200 @@ +/** + * Copyright 2014 Paul Querna + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package otp + +import ( + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" + + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" + "image" + "net/url" + "strings" +) + +// Error when attempting to convert the secret from base32 to raw bytes. +var ErrValidateSecretInvalidBase32 = errors.New("Decoding of secret as base32 failed.") + +// The user provided passcode length was not expected. +var ErrValidateInputInvalidLength = errors.New("Input length unexpected") + +// When generating a Key, the Issuer must be set. +var ErrGenerateMissingIssuer = errors.New("Issuer must be set") + +// When generating a Key, the Account Name must be set. +var ErrGenerateMissingAccountName = errors.New("AccountName must be set") + +// Key represents an TOTP or HTOP key. +type Key struct { + orig string + url *url.URL +} + +// NewKeyFromURL creates a new Key from an TOTP or HOTP url. +// +// The URL format is documented here: +// https://github.com/google/google-authenticator/wiki/Key-Uri-Format +// +func NewKeyFromURL(orig string) (*Key, error) { + u, err := url.Parse(orig) + + if err != nil { + return nil, err + } + + return &Key{ + orig: orig, + url: u, + }, nil +} + +func (k *Key) String() string { + return k.orig +} + +// Image returns an QR-Code image of the specified width and height, +// suitable for use by many clients like Google-Authenricator +// to enroll a user's TOTP/HOTP key. +func (k *Key) Image(width int, height int) (image.Image, error) { + b, err := qr.Encode(k.orig, qr.M, qr.Auto) + + if err != nil { + return nil, err + } + + b, err = barcode.Scale(b, width, height) + + if err != nil { + return nil, err + } + + return b, nil +} + +// Type returns "hotp" or "totp". +func (k *Key) Type() string { + return k.url.Host +} + +// Issuer returns the name of the issuing organization. +func (k *Key) Issuer() string { + q := k.url.Query() + + issuer := q.Get("issuer") + + if issuer != "" { + return issuer + } + + p := strings.TrimPrefix(k.url.Path, "/") + i := strings.Index(p, ":") + + if i == -1 { + return "" + } + + return p[:i] +} + +// AccountName returns the name of the user's account. +func (k *Key) AccountName() string { + p := strings.TrimPrefix(k.url.Path, "/") + i := strings.Index(p, ":") + + if i == -1 { + return p + } + + return p[i+1:] +} + +// Secret returns the opaque secret for this Key. +func (k *Key) Secret() string { + q := k.url.Query() + + return q.Get("secret") +} + +// Algorithm represents the hashing function to use in the HMAC +// operation needed for OTPs. +type Algorithm int + +const ( + AlgorithmSHA1 Algorithm = iota + AlgorithmSHA256 + AlgorithmSHA512 + AlgorithmMD5 +) + +func (a Algorithm) String() string { + switch a { + case AlgorithmSHA1: + return "SHA1" + case AlgorithmSHA256: + return "SHA256" + case AlgorithmSHA512: + return "SHA512" + case AlgorithmMD5: + return "MD5" + } + panic("unreached") +} + +func (a Algorithm) Hash() hash.Hash { + switch a { + case AlgorithmSHA1: + return sha1.New() + case AlgorithmSHA256: + return sha256.New() + case AlgorithmSHA512: + return sha512.New() + case AlgorithmMD5: + return md5.New() + } + panic("unreached") +} + +// Digits represents the number of digits present in the +// user's OTP passcode. Six and Eight are the most common values. +type Digits int + +const ( + DigitsSix Digits = 6 + DigitsEight Digits = 8 +) + +// Format converts an integer into the zero-filled size for this Digits. +func (d Digits) Format(in int32) string { + f := fmt.Sprintf("%%0%dd", d) + return fmt.Sprintf(f, in) +} + +// Length returns the number of characters for this Digits. +func (d Digits) Length() int { + return int(d) +} + +func (d Digits) String() string { + return fmt.Sprintf("%d", d) +} diff --git a/vendor/github.com/pquerna/otp/totp/totp.go b/vendor/github.com/pquerna/otp/totp/totp.go new file mode 100644 index 000000000..af5ab8296 --- /dev/null +++ b/vendor/github.com/pquerna/otp/totp/totp.go @@ -0,0 +1,191 @@ +/** + * Copyright 2014 Paul Querna + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package totp + +import ( + "github.com/pquerna/otp" + "github.com/pquerna/otp/hotp" + + "crypto/rand" + "encoding/base32" + "math" + "net/url" + "strconv" + "time" +) + +// Validate a TOTP using the current time. +// A shortcut for ValidateCustom, Validate uses a configuration +// that is compatible with Google-Authenticator and most clients. +func Validate(passcode string, secret string) bool { + rv, _ := ValidateCustom( + passcode, + secret, + time.Now().UTC(), + ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }, + ) + return rv +} + +// GenerateCode creates a TOTP token using the current time. +// A shortcut for GenerateCodeCustom, GenerateCode uses a configuration +// that is compatible with Google-Authenticator and most clients. +func GenerateCode(secret string, t time.Time) (string, error) { + return GenerateCodeCustom(secret, t, ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) +} + +// ValidateOpts provides options for ValidateCustom(). +type ValidateOpts struct { + // Number of seconds a TOTP hash is valid for. Defaults to 30 seconds. + Period uint + // Periods before or after the current time to allow. Value of 1 allows up to Period + // of either side of the specified time. Defaults to 0 allowed skews. Values greater + // than 1 are likely sketchy. + Skew uint + // Digits as part of the input. Defaults to 6. + Digits otp.Digits + // Algorithm to use for HMAC. Defaults to SHA1. + Algorithm otp.Algorithm +} + +// GenerateCodeCustom takes a timepoint and produces a passcode using a +// secret and the provided opts. (Under the hood, this is making an adapted +// call to hotp.GenerateCodeCustom) +func GenerateCodeCustom(secret string, t time.Time, opts ValidateOpts) (passcode string, err error) { + if opts.Period == 0 { + opts.Period = 30 + } + counter := uint64(math.Floor(float64(t.Unix()) / float64(opts.Period))) + passcode, err = hotp.GenerateCodeCustom(secret, counter, hotp.ValidateOpts{ + Digits: opts.Digits, + Algorithm: opts.Algorithm, + }) + if err != nil { + return "", err + } + return passcode, nil +} + +// ValidateCustom validates a TOTP given a user specified time and custom options. +// Most users should use Validate() to provide an interpolatable TOTP experience. +func ValidateCustom(passcode string, secret string, t time.Time, opts ValidateOpts) (bool, error) { + if opts.Period == 0 { + opts.Period = 30 + } + + counters := []uint64{} + counter := int64(math.Floor(float64(t.Unix()) / float64(opts.Period))) + + counters = append(counters, uint64(counter)) + for i := 1; i <= int(opts.Skew); i++ { + counters = append(counters, uint64(counter+int64(i))) + counters = append(counters, uint64(counter-int64(i))) + } + + for _, counter := range counters { + rv, err := hotp.ValidateCustom(passcode, counter, secret, hotp.ValidateOpts{ + Digits: opts.Digits, + Algorithm: opts.Algorithm, + }) + + if err != nil { + return false, err + } + + if rv == true { + return true, nil + } + } + + return false, nil +} + +// GenerateOpts provides options for Generate(). The default values +// are compatible with Google-Authenticator. +type GenerateOpts struct { + // Name of the issuing Organization/Company. + Issuer string + // Name of the User's Account (eg, email address) + AccountName string + // Number of seconds a TOTP hash is valid for. Defaults to 30 seconds. + Period uint + // Size in size of the generated Secret. Defaults to 10 bytes. + SecretSize uint + // Digits to request. Defaults to 6. + Digits otp.Digits + // Algorithm to use for HMAC. Defaults to SHA1. + Algorithm otp.Algorithm +} + +// Generate a new TOTP Key. +func Generate(opts GenerateOpts) (*otp.Key, error) { + // url encode the Issuer/AccountName + if opts.Issuer == "" { + return nil, otp.ErrGenerateMissingIssuer + } + + if opts.AccountName == "" { + return nil, otp.ErrGenerateMissingAccountName + } + + if opts.Period == 0 { + opts.Period = 30 + } + + if opts.SecretSize == 0 { + opts.SecretSize = 10 + } + + if opts.Digits == 0 { + opts.Digits = otp.DigitsSix + } + + // otpauth://totp/Example:alice@google.com?secret=JBSWY3DPEHPK3PXP&issuer=Example + + v := url.Values{} + secret := make([]byte, opts.SecretSize) + _, err := rand.Read(secret) + if err != nil { + return nil, err + } + + v.Set("secret", base32.StdEncoding.EncodeToString(secret)) + v.Set("issuer", opts.Issuer) + v.Set("period", strconv.FormatUint(uint64(opts.Period), 10)) + v.Set("algorithm", opts.Algorithm.String()) + v.Set("digits", opts.Digits.String()) + + u := url.URL{ + Scheme: "otpauth", + Host: "totp", + Path: "/" + opts.Issuer + ":" + opts.AccountName, + RawQuery: v.Encode(), + } + + return otp.NewKeyFromURL(u.String()) +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 4f5912d19..9e769960d 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -354,6 +354,18 @@ "revision": "675b82c74c0ed12283ee81ba8a534c8982c07b85", "revisionTime": "2016-10-13T10:26:35Z" }, + { + "checksumSHA1": "tLl952GRIVsso2Pk/IH3cMJaK8E=", + "path": "github.com/boombuler/barcode", + "revision": "9fb68fa6ca3535187c2a32f11a25e8d58f294bed", + "revisionTime": "2017-04-12T13:03:35Z" + }, + { + "checksumSHA1": "tbwzn+sWiZv6veyXae3qRfTjlcQ=", + "path": "github.com/boombuler/barcode/qr", + "revision": "9fb68fa6ca3535187c2a32f11a25e8d58f294bed", + "revisionTime": "2017-04-12T13:03:35Z" + }, { "checksumSHA1": "gX06B03sIRw/1yCms1kMwKX8krE=", "path": "github.com/cenk/backoff", @@ -1152,6 +1164,30 @@ "revision": "ff09b135c25aae272398c51a07235b90a75aa4f0", "revisionTime": "2017-03-16T20:15:38Z" }, + { + "checksumSHA1": "woY3inKe+d7B1jPTFxVKNCCFH9c=", + "path": "github.com/pquerna/otp", + "revision": "9e1935374bc73ffe011187dafed51a412b90fe43", + "revisionTime": "2017-02-23T01:06:52Z" + }, + { + "checksumSHA1": "5xpnYLhCOqNnsgykOk85MnTqVu0=", + "path": "github.com/pquerna/otp/example", + "revision": "9e1935374bc73ffe011187dafed51a412b90fe43", + "revisionTime": "2017-02-23T01:06:52Z" + }, + { + "checksumSHA1": "xo32aXW4ZXXRHJ/9E6m10vXJZAo=", + "path": "github.com/pquerna/otp/hotp", + "revision": "9e1935374bc73ffe011187dafed51a412b90fe43", + "revisionTime": "2017-02-23T01:06:52Z" + }, + { + "checksumSHA1": "Ie55pTQw1rnOZ8KDekSDXUWDT1I=", + "path": "github.com/pquerna/otp/totp", + "revision": "9e1935374bc73ffe011187dafed51a412b90fe43", + "revisionTime": "2017-02-23T01:06:52Z" + }, { "checksumSHA1": "ZOhewV1DsQjTYlx8a+ifrZki2Vg=", "path": "github.com/ryanuber/columnize", diff --git a/website/source/api/secret/totp/index.html.md b/website/source/api/secret/totp/index.html.md new file mode 100644 index 000000000..3cc6e5b06 --- /dev/null +++ b/website/source/api/secret/totp/index.html.md @@ -0,0 +1,272 @@ +--- +layout: "api" +page_title: "TOTP Secret Backend - HTTP API" +sidebar_current: "docs-http-secret-totp" +description: |- + This is the API documentation for the Vault TOTP secret backend. +--- + +# TOTP Secret Backend HTTP API + +This is the API documentation for the Vault TOTP secret backend. For +general information about the usage and operation of the TOTP backend, +please see the +[Vault TOTP backend documentation](/docs/secrets/totp/index.html). + +This documentation assumes the TOTP backend is mounted at the +`/totp` path in Vault. Since it is possible to mount secret backends at +any location, please update your API calls accordingly. + +## Create Key + +This endpoint creates or updates a key definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------------------------------------------------------------------------------- | +| `POST` | `/totp/keys/:name` | if generating a key and exported is true: `200 application/json` else: `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the key to create. This is specified as part of the URL. + +- `generate` `(bool: false)` – Specifies if a key should be generated by Vault or if a key is being passed from another service. + +- `exported` `(bool: true)` – Specifies if a QR code and url are returned upon generating a key. Only used if generate is true. + +- `key_size` `(int: 20)` – Specifies the size in bytes of the Vault generated key. Only used if generate is true. + +- `url` `(string: "")` – Specifies the TOTP key url string that can be used to configure a key. Only used if generate is false. + +- `key` `(string: )` – Specifies the master key used to generate a TOTP code. Only used if generate is false. + +- `issuer` `(string: "" )` – Specifies the name of the key’s issuing organization. + +- `account_name` `(string: "" )` – Specifies the name of the account associated with the key. + +- `period` `(int or duration format string: 30)` – Specifies the length of time in seconds used to generate a counter for the TOTP code calculation. + +- `algorithm` `(string: "SHA1")` – Specifies the hashing algorithm used to generate the TOTP code. Options include "SHA1", "SHA256" and "SHA512". + +- `digits` `(int: 6)` – Specifies the number of digits in the generated TOTP code. This value can be set to 6 or 8. + +- `skew` `(int: 1)` – Specifies the number of delay periods that are allowed when validating a TOTP code. This value can be either 0 or 1. Only used if generate is true. + +- `qr_size` `(int: 200)` – Specifies the pixel size of the square QR code when generating a new key. Only used if generate is true and exported is true. If this value is 0, a QR code will not be returned. + +### Sample Payload + +```json +{ + "url": "otpauth://totp/Google:test@gmail.com?secret=Y64VEVMBTSXCYIWRSHRNDZW62MPGVU2G&issuer=Google" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/totp/keys/my-key +``` + +### Sample Payload + +```json +{ + "generate": true, + "issuer": "Google", + "account_name": "test@gmail.com", +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/totp/keys/my-key +``` + +### Sample Response + +```json +{ + "data": { + "barcode": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADIEAAAAADYoy0BAAAGXklEQVR4nOyd4Y4iOQyEmRPv/8p7upX6BJm4XbbDbK30fT9GAtJJhpLjdhw3z1+/HmDEP396AvDO878/X1+9i1frWvu5Po/6Xz+P2kft1nFVa1f7z+YdjT/5PrEQMxDEDAQx4/n6orsGr6z9ZP1mviMbP/MBav/R6/U61Ud0vk8sxAwEMQNBzHju3lTvv6P2ajwS9Ve9zz+9pkfjRp+r/SjzwULMQBAzEMSMrQ/pUo0bouun7dW9LXVvrBq/TMBCzEAQMxDEjKM+JFqT17W4mu9Y+49eq/OL3r/GVX3CJ7KtWIgZCGIGgpix9SHTtXGa4476qfoa1adVc+HV/6/yfWIhZiCIGQhixpsP6Z4nulD3lqavV7q+Yvo6G7/zfWIhZiCIGQhixteJ/Rh1Da3e71d9RjRul2ocdeK7xELMQBAzEMSM3z6ku6dTrdOo1l9M6y5O7clVx5n4SCzEDAQxA0HMuN3L+qlavqj9itpePY+VtVdrHqfzeQULMQNBzEAQM97ikAv1vr/brltTeCp/svarcjLe2F1PnbohCGIGgphRqjG8mJ6PmtYMVnP363Vqv6d8qZrzf2AhfiCIGQhixm0c8n+jQ8+7+jZ4cY3PrlfHO/1Ml+45st18sRAzEMQMBDHjdxyixgPqs0lWsvvwqH00zrSO41R80p3XXXssxAwEMQNBzJCeuaieo6pedzGtb1/76fqgLH6ofg+dZ65gIWYgiBkIYsbbs9/V+/EVde1V+62eh1I/r/qIrs+Ixo2uYy/LGAQxA0HMeNvLilDX1OraXc2jVNtPzxJXr6v+HzuwEDMQxAwEMWNbp95d21WmzzBR6066e07dPMq0XoW9LEMQxAwEMUOqUz+1p9ONd07Xz586u6yifp/4EEMQxAwEMUPay7rIcthqTrx6v1/NTX+qZrIbF63v34GFmIEgZiCIGdvfU++e1a3GM2oOPjtvpfbfjS+qeZFJXgcLMQNBzEAQM6Tn9p7OLVdrFqP5TFF9ZXTdqfqTV7AQMxDEDAQx482HdPMPGdN8SjeHr6710zzJidrCB/kQTxDEDAQxY7uXdTGNC9S9pK6vqs6nWzdyej53PhELMQNBzEAQM0o59YtTz/xQfVO3jmOdl0rmE6f5ort5YSFmIIgZCGLGbU69eka3ep+v5sCzcbp5jZXMR0zr+aPPqVM3BkHMQBAzRs/tjejmwj9d05ihzq96nQr5EEMQxAwEMWPrQy6q9/fdevFTcVA0v+n5K7U/tf4lGhcfYgiCmIEgZtw+6+RCXUurvkKlepZ2vS5i+oyTaby0GxcLMQNBzEAQM0r5kKnv6K6xK9X4R13zu+eyJnXpazssxAwEMQNBzNj+fkg3nqjGK9laPz1vleXwq2v+p+vciUMMQRAzEMSM298xrOYDVqrtpmtzt59uHqc6v2zcBxbiB4KYgSBmbOvUV7q577VdOIliXqLr87p7Tere2YnrsRAzEMQMBDFj+zuGar3Gp+rNp3kUtR5lmj/Jxo/GvZsvFmIGgpiBIGbcPi/rW+MPPaeqOs407xL1E1E9lzWpg8FCzEAQMxDEDOk3qC66a7f6fsSn1uz18+o8P+GzsBAzEMQMBDFjm1Ov7L3s3p+2/6lcfoa6ZxaNm50DWyEOMQRBzEAQM7Zne6PX3XilW5M3zbd0c/3ZHpvqY6P+7j7HQsxAEDMQxIxRPqRaT6Kuzemkh7WJ3RrJbJxq7eOuPyzEDAQxA0HMKJ3t/XbxobW/Gmdka/PpPMxPgoWYgSBmIIgZ0m9QrXTP1mb9Ru2y+/hsD2xaM9jN5UfjEIf8RSCIGQhiRus3qLp7ONU6jK4vynxMdn10XdY+m4/SHxZiBoKYgSBm3MYhGdl9/qkzvN18ilpDqF6nxiPVGs3Xz7EQMxDEDAQx4/ZcVoR6fqobZ6h7Vtm81TVejZdWuvHNXXssxAwEMQNBzHju3pyujdO68Ky9Wm+h9qPGJVG/6nyU+WIhZiCIGQhixtaHdFF9hlqLeOrcVPcMQDeOmtTNYyFmIIgZCGLGUR/SPQs73QuL5tGtiVznlc1X/T8iXtthIWYgiBkIYsbWh3T3nNS1dXqe6tReW8S0Hr1b5/LAQvxAEDMQxIw3H9I9nzU9R6XGHdn41dx4d4+rGp9En7OX9ReAIGYgiBlff6IWG2KwEDP+DQAA//+TDHXGhqE4+AAAAABJRU5ErkJggg==", + "url" : "otpauth://totp/Google:test@gmail.com?algorithm=SHA1&digits=6&issuer=Google&period=30&secret=HTXT7KJFVNAJUPYWQRWMNVQE5AF5YZI2", + } +} +``` + +If a QR code is returned, it consists of base64-formatted PNG bytes. You can embed it in a web page by including the base64 string in an `img` tag with the prefix `data:image/png;base64` + +``` + +``` + +## Read Key + +This endpoint queries the key definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/totp/keys/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the key to read. This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + https://vault.rocks/v1/totp/keys/my-key +``` + +### Sample Response + +```json +{ + "data": { + "account_name": "test@gmail.com", + "algorithm" : "SHA1", + "digits" : 6, + "issuer": "Google", + "period" : 30, + } +} +``` + +## List Keys + +This endpoint returns a list of available keys. Only the key names are +returned, not any values. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `LIST` | `/totp/keys` | `200 application/json` | + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request LIST \ + https://vault.rocks/v1/totp/keys +``` + +### Sample Response + +```json +{ + "auth": null, + "data": { + "keys": ["my-key"] + }, + "lease_duration": 0, + "lease_id": "", + "renewable": false +} +``` + +## Delete Key + +This endpoint deletes the key definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `DELETE` | `/totp/keys/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the key to delete. This + is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request DELETE \ + https://vault.rocks/v1/totp/keys/my-key +``` + +## Generate Code + +This endpoint generates a new time-based one-time use password based on the named +key. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/totp/code/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the key to create + credentials against. This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + https://vault.rocks/v1/totp/code/my-key +``` + +### Sample Response + +```json +{ + "data": { + "code": "810920", + } +} +``` + +## Validate Code + +This endpoint validates a time-based one-time use password generated from the named +key. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/totp/code/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the key used to generate the password. This is specified as part of the URL. + +- `code` `(string: )` – Specifies the password you want to validate. + +### Sample Payload + +```json +{ + "code": "123802" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/totp/code/my-key +``` + +### Sample Response + +```json +{ + "data": { + "valid": true, + } +} +``` diff --git a/website/source/docs/secrets/totp/index.html.md b/website/source/docs/secrets/totp/index.html.md new file mode 100644 index 000000000..97cebceb8 --- /dev/null +++ b/website/source/docs/secrets/totp/index.html.md @@ -0,0 +1,83 @@ +--- +layout: "docs" +page_title: "TOTP Secret Backend" +sidebar_current: "docs-secrets-totp" +description: |- + The TOTP secret backend for Vault generates time-based one-time use passwords. +--- + +# TOTP Secret Backend + +Name: `totp` + +The TOTP secret backend for Vault will allow Vault users to store their multi-factor +authentication keys in Vault and use the API to retrieve time-based one-time use passwords +on demand. The backend can also be used to generate a new key and validate passwords generated by that key. + +This page will show a quick start for this backend. For detailed documentation +on every path, use `vault path-help` after mounting the backend. + +## Quick Start + +The first step to using the TOTP backend is to mount it. +Unlike the `generic` backend, the `totp` backend is not mounted by default. + +```text +$ vault mount totp +Successfully mounted 'totp' at 'totp'! +``` + +The next step is to configure a key. For example, lets create +a "test" key by passing in a TOTP key url: + +```text +$ vault write totp/keys/test \ + url="otpauth://totp/Vault:test@gmail.com?secret=Y64VEVMBTSXCYIWRSHRNDZW62MPGVU2G&issuer=Vault" +Success! Data written to: totp/keys/test +``` + +By writing to the `keys/test` path we are defining the `test` key. + +To generate a new set of credentials, we simply read from that key using the `code` path: + +```text +$ vault read totp/code/test +Key Value +code 135031 +``` +Vault is now configured to create time-based one-time use passwords! + +By reading from the `code/test` path, Vault has generated a new +time-based one-time use password using the `test` key configuration. + +Using ACLs, it is possible to restrict using the TOTP backend such +that trusted operators can manage the key definitions, and both +users and applications are restricted in the credentials they are +allowed to read. + +The TOTP backend can also be used to generate new keys and validate passwords generated using those keys. + +In order to generate a new key, set the generate flag to true and pass in an issuer and account name. + +```text +$ vault write totp/keys/test \ + generate=true issuer=Vault account_name=test@gmail.com +``` +A base64 encoded barcode and url will be returned upon generating a new key. These can be given to client applications that +can generate passwords. You can validate those passwords by writing to the `code/test` path. + +```text +$ vault write totp/code/test \ + code=127388 +Key Value +valid true +``` + +If you get stuck at any time, simply run `vault path-help totp` or with a +subpath for interactive help output. + +## API + +The TOTP secret backend has a full HTTP API. Please see the +[TOTP secret backend API](/api/secret/totp/index.html) for more +details. diff --git a/website/source/layouts/api.erb b/website/source/layouts/api.erb index b4429a02b..989c60bd8 100644 --- a/website/source/layouts/api.erb +++ b/website/source/layouts/api.erb @@ -53,6 +53,9 @@ > SSH + > + TOTP + > Transit diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index f5317171f..3ff5dc531 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -247,6 +247,10 @@ SSH + > + TOTP + + > Transit From 82b58d5b9cea02f673d30d2da4fe30afde6606ae Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 11:45:27 -0700 Subject: [PATCH 149/162] Update docs and return a better error message --- builtin/logical/database/backend.go | 2 +- website/source/api/secret/databases/index.html.md | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 3d1502805..91b92e438 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -106,7 +106,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin. func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) { entry, err := s.Get(fmt.Sprintf("config/%s", name)) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + return nil, fmt.Errorf("failed to read connection configuration: %s", err) } if entry == nil { return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) diff --git a/website/source/api/secret/databases/index.html.md b/website/source/api/secret/databases/index.html.md index f55998ace..d43e49789 100644 --- a/website/source/api/secret/databases/index.html.md +++ b/website/source/api/secret/databases/index.html.md @@ -162,11 +162,13 @@ This endpoint creates or updates a role definition. - `db_name` `(string: )` - The name of the database connection to use for this role. -- `default_ttl` `(string: )` - Specifies the TTL for the lease - associated with this role. +- `default_ttl` `(string/int: 0)` - Specifies the TTL for the leases + associated with this role. Accepts time suffixed strings ("1h") or an integer + number of seconds. Defaults to system/backend default TTL time. -- `max_ttl` `(string: )` - Specifies the maximum TTL for the lease - associated with this role. +- `max_ttl` `(string/int: 0)` - Specifies the maximum TTL for the leases + associated with this role. Accepts time suffixed strings ("1h") or an integer + number of seconds. Defaults to system/backend default TTL time. - `creation_statements` `(string: )` – Specifies the database statements executed to create and configure a user. Must be a From 2d6dfbf14762e8cf92b06003c7a0fc54a2604d9d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 12:36:06 -0700 Subject: [PATCH 150/162] Don't store the plugin directory prepended command in the barrier, prepend on get --- vault/plugin_catalog.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 095d81b1e..79474e601 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -59,6 +59,9 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { return nil, fmt.Errorf("failed to decode plugin entry: %v", err) } + // prepend the plugin directory to the command + entry.Command = filepath.Join(c.directory, entry.Command) + return entry, nil } } @@ -85,14 +88,11 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { defer c.lock.Unlock() parts := strings.Split(command, " ") - command = parts[0] - args := parts[1:] - - command = filepath.Join(c.directory, command) // Best effort check to make sure the command isn't breaking out of the // configured plugin directory. - sym, err := filepath.EvalSymlinks(command) + commandFull := filepath.Join(c.directory, parts[0]) + sym, err := filepath.EvalSymlinks(commandFull) if err != nil { return fmt.Errorf("error while validating the command path: %v", err) } @@ -107,8 +107,8 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { entry := &pluginutil.PluginRunner{ Name: name, - Command: command, - Args: args, + Command: parts[0], + Args: parts[1:], Sha256: sha256, Builtin: false, } From 812841ff38679e44e4c6dab98c0aa6df13762199 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 12:48:52 -0700 Subject: [PATCH 151/162] changelog++ --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4546a9e9e..cbc9e945b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,10 @@ FEATURES: revoke-force) have also been relocated to `sys/leases/`, but they also work at the old paths for compatibility. Reading (but not listing) leases via `sys/leases/lookup` is now a part of the current `default` policy. [GH-2650] + * **TOTP Secret Backend**: You can now store multi-factor authentication keys + in Vault and use the API to retrieve time-based one-time use passwords on + demand. The backend can also be used to generate a new key and validate + passwords generated by that key. IMPROVEMENTS: From a255cd9b93a83af38e9995f627195270e47a3a0a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 12:49:47 -0700 Subject: [PATCH 152/162] changelog++ --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbc9e945b..5fb97821f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ FEATURES: * **TOTP Secret Backend**: You can now store multi-factor authentication keys in Vault and use the API to retrieve time-based one-time use passwords on demand. The backend can also be used to generate a new key and validate - passwords generated by that key. + passwords generated by that key. [GH-2492] IMPROVEMENTS: From 3c41bdfa16f66a96b6362152683eaace624acdd2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 13:38:49 -0700 Subject: [PATCH 153/162] update docs --- website/source/docs/secrets/databases/mysql-maria.html.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index f4cf3640b..0d7c49748 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -15,9 +15,6 @@ The MySQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the MySQL database. -See the [Database Backend](/docs/secrets/databases/index.html) docs for more -information about setting up the Database Backend. - This plugin has a few different instances built into vault, each instance is for a slightly different MySQL driver. The only difference between these plugins is the length of usernames generated by the plugin as different versions of mysql @@ -28,6 +25,9 @@ accept different lengths. The availible plugins are: - mysql-rds-database-plugin - mysql-legacy-database-plugin +See the [Database Backend](/docs/secrets/databases/index.html) docs for more +information about setting up the Database Backend. + ## Quick Start After the Database Backend is mounted you can configure a MySQL connection From b49993f81f2f1d7d8cef404af09bb58f6c80bc01 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Thu, 4 May 2017 16:46:34 -0400 Subject: [PATCH 154/162] Update mssql docs --- website/source/docs/secrets/databases/mssql.html.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index c2f7ff5fe..63ea31c44 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -39,11 +39,11 @@ Once the MSSQL connection is configured we can add a role: $ vault write database/roles/readonly \ db_name=mssql \ creation_statements="CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';\ - USE AdventureWorks; CREATE USER [{{name}}] FOR LOGIN [{{name}}]; \ - GRANT SELECT ON SCHEMA::dbo TO [{{name}}];" \ + CREATE USER [{{name}}] FOR LOGIN [{{name}}];\ + GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];" \ default_ttl="1h" \ max_ttl="24h" - + Success! Data written to: database/roles/readonly ``` From f01b413d8de5a429692350830ce11b1d7041e9bc Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 4 May 2017 16:58:50 -0400 Subject: [PATCH 155/162] Make path-help request forward (#2677) --- http/forwarding_test.go | 30 ++++++++++++++++++++++++++++++ http/help.go | 12 ++++++++---- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/http/forwarding_test.go b/http/forwarding_test.go index fdc3b76fc..e7d285e24 100644 --- a/http/forwarding_test.go +++ b/http/forwarding_test.go @@ -595,3 +595,33 @@ func TestHTTP_Forwarding_ClientTLS(t *testing.T) { } } } + +func TestHTTP_Forwarding_HelpOperation(t *testing.T) { + handler1 := http.NewServeMux() + handler2 := http.NewServeMux() + handler3 := http.NewServeMux() + + cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, &vault.CoreConfig{}, true) + for _, core := range cores { + defer core.CloseListeners() + } + + handler1.Handle("/", Handler(cores[0].Core)) + handler2.Handle("/", Handler(cores[1].Core)) + handler3.Handle("/", Handler(cores[2].Core)) + + vault.TestWaitActive(t, cores[0].Core) + + testHelp := func(client *api.Client) { + help, err := client.Help("auth/token") + if err != nil { + t.Fatal(err) + } + if help == nil { + t.Fatal("help was nil") + } + } + + testHelp(cores[0].Client) + testHelp(cores[1].Client) +} diff --git a/http/help.go b/http/help.go index f0ca8b170..1c3a9560f 100644 --- a/http/help.go +++ b/http/help.go @@ -8,14 +8,18 @@ import ( ) func wrapHelpHandler(h http.Handler, core *vault.Core) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // If the help parameter is not blank, then show the help + return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { + // If the help parameter is not blank, then show the help. We request + // forward because standby nodes do not have mounts and other state. if v := req.URL.Query().Get("help"); v != "" || req.Method == "HELP" { - handleHelp(core, w, req) + handleRequestForwarding(core, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleHelp(core, w, r) + })).ServeHTTP(writer, req) return } - h.ServeHTTP(w, req) + h.ServeHTTP(writer, req) return }) } From 13940b59e4fa2257e787cbe44b96560515d2f825 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 4 May 2017 16:59:44 -0400 Subject: [PATCH 156/162] changelog++ --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fb97821f..7bb5f3e89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,8 @@ BUG FIXES: * auth/ldap: Don't lowercase groups attached to users [GH-2613] * cli: Don't panic if `vault write` is used with the `force` flag but no path [GH-2674] + * core: Help operations should request forward since standbys may not have + appropriate info [GH-2677] * secret/mssql: Update mssql driver to support queries with colons [GH-2610] * secret/pki: Don't lowercase O/OU values in certs [GH-2555] * secret/pki: Don't attempt to validate IP SANs if none are provided [GH-2574] From 16e6f9640d538919164617104b49ea853bd00845 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 14:07:12 -0700 Subject: [PATCH 157/162] Few docs updates --- website/source/docs/secrets/databases/mssql.html.md | 2 +- website/source/docs/secrets/databases/postgresql.html.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 63ea31c44..889e35a4f 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -40,7 +40,7 @@ $ vault write database/roles/readonly \ db_name=mssql \ creation_statements="CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';\ CREATE USER [{{name}}] FOR LOGIN [{{name}}];\ - GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];" \ + GRANT SELECT ON SCHEMA::dbo TO [{{name}}];" \ default_ttl="1h" \ max_ttl="24h" diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index e04cc087c..b2c0c7bb2 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -27,7 +27,7 @@ configuration: $ vault write database/config/postgresql \ plugin_name=postgresql-database-plugin \ allowed_roles="readonly" \ - connection_url="postgresql://root:root@localhost:5432/postgres" + connection_url="postgresql://root:root@localhost:5432/" The following warnings were returned from the Vault server: * Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. From 24b27f99f8c9bbc29dd2ab734073709131105003 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 15:27:43 -0700 Subject: [PATCH 158/162] Changelog++ --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bb5f3e89..384e5efe5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,10 @@ FEATURES: in Vault and use the API to retrieve time-based one-time use passwords on demand. The backend can also be used to generate a new key and validate passwords generated by that key. [GH-2492] + * **Database Secret Backend & Secure Plugins**: This new secret backend + combines the functionality of the MySQL, PostgreSQL, MSSQL, and Cassandra + backends. It also provides a plugin interface for extendability through + custom databases. IMPROVEMENTS: From 0a8ee020650b3688c77bae9d124460292b79a140 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 15:29:05 -0700 Subject: [PATCH 159/162] changelog++ --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 384e5efe5..68407103e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ FEATURES: * **Database Secret Backend & Secure Plugins**: This new secret backend combines the functionality of the MySQL, PostgreSQL, MSSQL, and Cassandra backends. It also provides a plugin interface for extendability through - custom databases. + custom databases. [GH-2200] IMPROVEMENTS: From 913a11268190a846c40c920db15e825dbc957463 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 17:55:30 -0700 Subject: [PATCH 160/162] Update mssql.html.md --- website/source/api/secret/databases/mssql.html.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/source/api/secret/databases/mssql.html.md b/website/source/api/secret/databases/mssql.html.md index 09893df45..d4b120e8d 100644 --- a/website/source/api/secret/databases/mssql.html.md +++ b/website/source/api/secret/databases/mssql.html.md @@ -25,8 +25,8 @@ has a number of parameters to further configure a connection. ### Parameters - `connection_url` `(string: )` - Specifies the MSSQL DSN. -- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use - for this connection. +- `max_open_connections` `(int: 2)` - Specifies the maximum number of open + connections to the database. - `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle connections to the database. A zero uses the value of `max_open_connections` From 20cc43bf18d3f6ea3f46a051b4b9928d3c564aa8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 17:55:50 -0700 Subject: [PATCH 161/162] Update mysql-maria.html.md --- website/source/api/secret/databases/mysql-maria.html.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/source/api/secret/databases/mysql-maria.html.md b/website/source/api/secret/databases/mysql-maria.html.md index 981506798..b4657eaa6 100644 --- a/website/source/api/secret/databases/mysql-maria.html.md +++ b/website/source/api/secret/databases/mysql-maria.html.md @@ -25,8 +25,8 @@ has a number of parameters to further configure a connection. ### Parameters - `connection_url` `(string: )` - Specifies the MySQL DSN. -- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use - for this connection. +- `max_open_connections` `(int: 2)` - Specifies the maximum number of open + connections to the database. - `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle connections to the database. A zero uses the value of `max_open_connections` From 61f115ba8137463947bad9cfd77b26a9e7e37c34 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 17:56:09 -0700 Subject: [PATCH 162/162] Update postgresql.html.md --- website/source/api/secret/databases/postgresql.html.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/source/api/secret/databases/postgresql.html.md b/website/source/api/secret/databases/postgresql.html.md index 5ff6b8022..a1aaeee1c 100644 --- a/website/source/api/secret/databases/postgresql.html.md +++ b/website/source/api/secret/databases/postgresql.html.md @@ -25,8 +25,8 @@ has a number of parameters to further configure a connection. ### Parameters - `connection_url` `(string: )` - Specifies the PostgreSQL DSN. -- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use - for this connection. +- `max_open_connections` `(int: 2)` - Specifies the maximum number of open + connections to the database. - `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle connections to the database. A zero uses the value of `max_open_connections`