Added unit tests for token entry upgrade

This commit is contained in:
vishalnayak 2016-09-26 18:17:50 -04:00
parent af888573be
commit 57b21acabb
2 changed files with 170 additions and 2 deletions

View File

@ -470,7 +470,7 @@ type TokenEntry struct {
TTL time.Duration `json:"ttl" mapstructure:"ttl" structs:"ttl"`
// Explicit maximum TTL on the token
ExplicitMaxTTL time.Duration `json:"" mapstructure:"" structs:""`
ExplicitMaxTTL time.Duration `json:"explicit_max_ttl" mapstructure:"explicit_max_ttl" structs:"explicit_max_ttl"`
// If set, the role that was used for parameters at creation time
Role string `json:"role" mapstructure:"role" structs:"role"`
@ -806,7 +806,7 @@ func (ts *TokenStore) lookupSalted(saltedId string) (*TokenEntry, error) {
// Upgrade the deprecated fields
if entry.DisplayNameDeprecated != "" {
if entry.DisplayName != "" {
if entry.DisplayName == "" {
entry.DisplayName = entry.DisplayNameDeprecated
}
entry.DisplayNameDeprecated = ""

View File

@ -1,15 +1,183 @@
package vault
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"testing"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical"
)
type TokenEntryOld struct {
ID string
Accessor string
Parent string
Policies []string
Path string
Meta map[string]string
DisplayName string
NumUses int
CreationTime int64
TTL time.Duration
ExplicitMaxTTL time.Duration
Role string
Period time.Duration
}
func TestTokenStore_TokenEntryUpgrade(t *testing.T) {
var err error
_, ts, _, _ := TestCoreWithTokenStore(t)
// Use a struct that does not have struct tags to store the items and
// check if the lookup code handles them properly while reading back
entry := &TokenEntryOld{
DisplayName: "test-display-name",
Path: "test",
Policies: []string{"dev", "ops"},
CreationTime: time.Now().Unix(),
ExplicitMaxTTL: 100,
NumUses: 10,
}
entry.ID, err = uuid.GenerateUUID()
if err != nil {
t.Fatal(err)
}
enc, err := json.Marshal(entry)
if err != nil {
t.Fatal(err)
}
saltedId := ts.SaltID(entry.ID)
path := lookupPrefix + saltedId
le := &logical.StorageEntry{
Key: path,
Value: enc,
}
if err := ts.view.Put(le); err != nil {
t.Fatal(err)
}
out, err := ts.Lookup(entry.ID)
if err != nil {
t.Fatalf("err: %s", err)
}
if out.DisplayName != "test-display-name" {
t.Fatalf("bad: display_name: expected: test-display-name, actual: %s", out.DisplayName)
}
if out.CreationTime == 0 {
t.Fatal("bad: expected a non-zero creation time")
}
if out.ExplicitMaxTTL != 100 {
t.Fatalf("bad: explicit_max_ttl: expected: 100, actual: %d", out.ExplicitMaxTTL)
}
if out.NumUses != 10 {
t.Fatalf("bad: num_uses: expected: 10, actual: %d", out.NumUses)
}
// Test the default case to ensure there are no regressions
ent := &TokenEntry{
DisplayName: "test-display-name",
Path: "test",
Policies: []string{"dev", "ops"},
CreationTime: time.Now().Unix(),
ExplicitMaxTTL: 100,
NumUses: 10,
}
if err := ts.create(ent); err != nil {
t.Fatalf("err: %s", err)
}
out, err = ts.Lookup(ent.ID)
if err != nil {
t.Fatalf("err: %s", err)
}
if out.DisplayName != "test-display-name" {
t.Fatalf("bad: display_name: expected: test-display-name, actual: %s", out.DisplayName)
}
if out.CreationTime == 0 {
t.Fatal("bad: expected a non-zero creation time")
}
if out.ExplicitMaxTTL != 100 {
t.Fatalf("bad: explicit_max_ttl: expected: 100, actual: %d", out.ExplicitMaxTTL)
}
if out.NumUses != 10 {
t.Fatalf("bad: num_uses: expected: 10, actual: %d", out.NumUses)
}
// Fill in the deprecated fields and read out from proper fields
ent = &TokenEntry{
Path: "test",
Policies: []string{"dev", "ops"},
DisplayNameDeprecated: "test-display-name",
CreationTimeDeprecated: time.Now().Unix(),
ExplicitMaxTTLDeprecated: 100,
NumUsesDeprecated: 10,
}
if err := ts.create(ent); err != nil {
t.Fatalf("err: %s", err)
}
out, err = ts.Lookup(ent.ID)
if err != nil {
t.Fatalf("err: %s", err)
}
if out.DisplayName != "test-display-name" {
t.Fatalf("bad: display_name: expected: test-display-name, actual: %s", out.DisplayName)
}
if out.CreationTime == 0 {
t.Fatal("bad: expected a non-zero creation time")
}
if out.ExplicitMaxTTL != 100 {
t.Fatalf("bad: explicit_max_ttl: expected: 100, actual: %d", out.ExplicitMaxTTL)
}
if out.NumUses != 10 {
t.Fatalf("bad: num_uses: expected: 10, actual: %d", out.NumUses)
}
// Check if NumUses picks up a lower value
ent = &TokenEntry{
Path: "test",
NumUses: 5,
NumUsesDeprecated: 10,
}
if err := ts.create(ent); err != nil {
t.Fatalf("err: %s", err)
}
out, err = ts.Lookup(ent.ID)
if err != nil {
t.Fatalf("err: %s", err)
}
if out.NumUses != 5 {
t.Fatalf("bad: num_uses: expected: 5, actual: %d", out.NumUses)
}
// Switch the values from deprecated and proper field and check if the
// lower value is still getting picked up
ent = &TokenEntry{
Path: "test",
NumUses: 10,
NumUsesDeprecated: 5,
}
if err := ts.create(ent); err != nil {
t.Fatalf("err: %s", err)
}
out, err = ts.Lookup(ent.ID)
if err != nil {
t.Fatalf("err: %s", err)
}
if out.NumUses != 5 {
t.Fatalf("bad: num_uses: expected: 5, actual: %d", out.NumUses)
}
}
func getBackendConfig(c *Core) *logical.BackendConfig {
return &logical.BackendConfig{
Logger: c.logger,