diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index d605153a4..a72cdefd6 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -131,6 +131,21 @@ func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*Datab return &config, nil } +type upgradeStatements struct { + // This json tag has a typo in it, the new version does not. This + // necessitates this upgrade logic. + CreationStatements string `json:"creation_statments"` + RevocationStatements string `json:"revocation_statements"` + RollbackStatements string `json:"rollback_statements"` + RenewStatements string `json:"renew_statements"` +} + +type upgradeCheck struct { + // This json tag has a typo in it, the new version does not. This + // necessitates this upgrade logic. + Statements upgradeStatements `json:"statments"` +} + func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) { entry, err := s.Get("role/" + roleName) if err != nil { @@ -140,11 +155,24 @@ func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, return nil, nil } + var upgradeCh upgradeCheck + if err := entry.DecodeJSON(&upgradeCh); err != nil { + return nil, err + } + var result roleEntry if err := entry.DecodeJSON(&result); err != nil { return nil, err } + empty := upgradeCheck{} + if upgradeCh != empty { + result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements + result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements + result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements + result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements + } + return &result, nil } diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 35d3639cd..64f5e868b 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -116,6 +116,55 @@ func TestBackend_PluginMain(t *testing.T) { postgresql.Run(apiClientMeta.GetTLSConfig()) } +func TestBackend_RoleUpgrade(t *testing.T) { + + storage := &logical.InmemStorage{} + backend := &databaseBackend{} + + roleEnt := &roleEntry{ + Statements: dbplugin.Statements{ + CreationStatements: "test", + }, + } + + entry, err := logical.StorageEntryJSON("role/test", roleEnt) + if err != nil { + t.Fatal(err) + } + if err := storage.Put(entry); err != nil { + t.Fatal(err) + } + + role, err := backend.Role(storage, "test") + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(role, roleEnt) { + t.Fatal("bad role %#v", role) + } + + // Upgrade case + badJSON := `{"statments":{"creation_statments":"test","revocation_statements":"","rollback_statements":"","renew_statements":""}}` + entry = &logical.StorageEntry{ + Key: "role/test", + Value: []byte(badJSON), + } + if err := storage.Put(entry); err != nil { + t.Fatal(err) + } + + role, err = backend.Role(storage, "test") + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(role, roleEnt) { + t.Fatal("bad role %#v", role) + } + +} + func TestBackend_config_connection(t *testing.T) { var resp *logical.Response var err error diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 69884cb3a..9404aee85 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -181,7 +181,7 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { type roleEntry struct { DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - Statements dbplugin.Statements `json:"statments" mapstructure:"statements" structs:"statments"` + Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"` DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` }