package postgresql

import (
	"database/sql"
	"fmt"
	"log"
	"os"
	"testing"

	"github.com/hashicorp/vault/logical"
	logicaltest "github.com/hashicorp/vault/logical/testing"
	"github.com/lib/pq"
	"github.com/mitchellh/mapstructure"
)

// TestBackend_basic writes the connection config (first under the
// "connection_url" key, then under the "value" key), and after each config
// write creates the "web" role and reads credentials for it.
func TestBackend_basic(t *testing.T) {
	b, err := Factory(logical.TestBackendConfig())
	if err != nil {
		t.Fatal(err)
	}

	d1 := map[string]interface{}{
		"connection_url": os.Getenv("PG_URL"),
	}
	d2 := map[string]interface{}{
		"value": os.Getenv("PG_URL"),
	}

	logicaltest.Test(t, logicaltest.TestCase{
		PreCheck: func() { testAccPreCheck(t) },
		Backend:  b,
		Steps: []logicaltest.TestStep{
			testAccStepConfig(t, d1, false),
			testAccStepRole(t),
			testAccStepReadCreds(t, b, "web"),
			testAccStepConfig(t, d2, false),
			testAccStepRole(t),
			testAccStepReadCreds(t, b, "web"),
		},
	})
}

// TestBackend_roleCrud creates the "web" role, reads it back expecting
// testRole, deletes it, and then reads it again expecting no role.
func TestBackend_roleCrud(t *testing.T) {
	b, err := Factory(logical.TestBackendConfig())
	if err != nil {
		t.Fatal(err)
	}

	d := map[string]interface{}{
		"connection_url": os.Getenv("PG_URL"),
	}

	logicaltest.Test(t, logicaltest.TestCase{
		PreCheck: func() { testAccPreCheck(t) },
		Backend:  b,
		Steps: []logicaltest.TestStep{
			testAccStepConfig(t, d, false),
			testAccStepRole(t),
			testAccStepReadRole(t, "web", testRole),
			testAccStepDeleteRole(t, "web"),
			testAccStepReadRole(t, "web", ""),
		},
	})
}

// TestBackend_configConnection writes the connection config with only
// "value", only "connection_url", both keys, and no keys at all. Only the
// empty payload is expected to produce an error response.
func TestBackend_configConnection(t *testing.T) {
	b := Backend()

	d1 := map[string]interface{}{
		"value": os.Getenv("PG_URL"),
	}
	d2 := map[string]interface{}{
		"connection_url": os.Getenv("PG_URL"),
	}
	d3 := map[string]interface{}{
		"value":          os.Getenv("PG_URL"),
		"connection_url": os.Getenv("PG_URL"),
	}
	d4 := map[string]interface{}{}

	logicaltest.Test(t, logicaltest.TestCase{
		PreCheck: func() { testAccPreCheck(t) },
		Backend:  b,
		Steps: []logicaltest.TestStep{
			testAccStepConfig(t, d1, false),
			testAccStepConfig(t, d2, false),
			testAccStepConfig(t, d3, false),
			testAccStepConfig(t, d4, true),
		},
	})
}

// testAccPreCheck fails the test immediately when the PG_URL environment
// variable is empty or unset.
func testAccPreCheck(t *testing.T) {
	if v := os.Getenv("PG_URL"); v == "" {
		t.Fatal("PG_URL must be set for acceptance tests")
	}
}

// testAccStepConfig returns a step that writes d to "config/connection".
//
// When expectError is false the step passes only if the response is nil.
// When expectError is true the step passes only if the response data carries
// a non-empty "error" field.
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
	return logicaltest.TestStep{
		Operation: logical.UpdateOperation,
		Path:      "config/connection",
		Data:      d,
		ErrorOk:   true,
		Check: func(resp *logical.Response) error {
			if !expectError {
				if resp != nil {
					return fmt.Errorf("response should be nil")
				}
				return nil
			}

			// A nil response carries no error payload either; report it as a
			// check failure instead of dereferencing it.
			if resp == nil || resp.Data == nil {
				return fmt.Errorf("data is nil")
			}

			var e struct {
				Error string `mapstructure:"error"`
			}
			if err := mapstructure.Decode(resp.Data, &e); err != nil {
				return err
			}
			if len(e.Error) == 0 {
				return fmt.Errorf("expected error, but write succeeded.")
			}
			return nil
		},
	}
}

// testAccStepRole returns a step that writes testRole as the "sql" field of
// the "web" role.
func testAccStepRole(t *testing.T) logicaltest.TestStep {
	return logicaltest.TestStep{
		Operation: logical.UpdateOperation,
		Path:      "roles/web",
		Data: map[string]interface{}{
			"sql": testRole,
		},
	}
}

// testAccStepDeleteRole returns a step that deletes the role named n.
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
	return logicaltest.TestStep{
		Operation: logical.DeleteOperation,
		Path:      "roles/" + n,
	}
}

// testAccStepReadCreds returns a step that reads "creds/<name>" and then
// verifies the generated user against the database at PG_URL:
//
//  1. a privilege query mentioning the generated username must yield exactly
//     two rows;
//  2. the credentials are revoked through b;
//  3. the same query must then fail (reported as -1 rows).
func testAccStepReadCreds(t *testing.T, b logical.Backend, name string) logicaltest.TestStep {
	return logicaltest.TestStep{
		Operation: logical.ReadOperation,
		Path:      "creds/" + name,
		Check: func(resp *logical.Response) error {
			// Without a response there are no credentials to verify.
			if resp == nil {
				return fmt.Errorf("response is nil")
			}

			var d struct {
				Username string `mapstructure:"username"`
				Password string `mapstructure:"password"`
			}
			if err := mapstructure.Decode(resp.Data, &d); err != nil {
				return err
			}
			log.Printf("[WARN] Generated credentials: %v", d)

			conn, err := pq.ParseURL(os.Getenv("PG_URL"))
			if err != nil {
				t.Fatal(err)
			}
			conn += " timezone=utc"

			db, err := sql.Open("postgres", conn)
			if err != nil {
				t.Fatal(err)
			}
			// Release the connection pool when the check finishes; deferred
			// calls still run when t.Fatal aborts the goroutine.
			defer db.Close()

			// returnedRows counts the rows produced by the privilege query for
			// the generated user, or returns -1 if the query fails at any
			// stage (prepare, execute, or row iteration).
			returnedRows := func() int {
				stmt, err := db.Prepare(fmt.Sprintf(
					"SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege('%s', 'information_schema.role_column_grants', 'select');",
					d.Username))
				if err != nil {
					return -1
				}
				defer stmt.Close()

				rows, err := stmt.Query()
				if err != nil {
					return -1
				}
				defer rows.Close()

				i := 0
				for rows.Next() {
					i++
				}
				// An iteration error means the count is incomplete; treat it
				// like any other query failure.
				if rows.Err() != nil {
					return -1
				}
				return i
			}

			userRows := returnedRows()
			if userRows != 2 {
				t.Fatalf("did not get expected number of rows, got %d", userRows)
			}

			resp, err = b.HandleRequest(&logical.Request{
				Operation: logical.RevokeOperation,
				Secret: &logical.Secret{
					InternalData: map[string]interface{}{
						"secret_type": "creds",
						"username":    d.Username,
					},
				},
			})
			if err != nil {
				return err
			}
			if resp != nil {
				if resp.IsError() {
					return fmt.Errorf("Error on resp: %#v", *resp)
				}
			}

			userRows = returnedRows()
			// User shouldn't exist so returnedRows() should encounter an error and exit with -1
			if userRows != -1 {
				t.Fatalf("did not get expected number of rows, got %d", userRows)
			}

			return nil
		},
	}
}

// testAccStepReadRole returns a step that reads "roles/<name>" and compares
// the returned "sql" field with sql. A nil response is accepted only when sql
// is the empty string, i.e. when the role is expected to be absent.
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
	return logicaltest.TestStep{
		Operation: logical.ReadOperation,
		Path:      "roles/" + name,
		Check: func(resp *logical.Response) error {
			if resp == nil {
				if sql == "" {
					return nil
				}
				return fmt.Errorf("bad: %#v", resp)
			}

			var d struct {
				SQL string `mapstructure:"sql"`
			}
			if err := mapstructure.Decode(resp.Data, &d); err != nil {
				return err
			}
			if d.SQL != sql {
				return fmt.Errorf("bad: %#v", resp)
			}
			return nil
		},
	}
}

// testRole is the SQL template written to the "web" role by testAccStepRole.
// It uses the {{name}}, {{password}} and {{expiration}} placeholders.
const testRole = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; `