diff --git a/sdk/database/helper/credsutil/sql.go b/sdk/database/helper/credsutil/sql.go index 986631da9..39fb467a7 100644 --- a/sdk/database/helper/credsutil/sql.go +++ b/sdk/database/helper/credsutil/sql.go @@ -2,8 +2,6 @@ package credsutil import ( "context" - "fmt" - "strings" "time" "github.com/hashicorp/vault/sdk/database/dbplugin" @@ -31,46 +29,17 @@ func (scp *SQLCredentialsProducer) GenerateCredentials(ctx context.Context) (str } func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConfig) (string, error) { - username := "v" - - displayName := config.DisplayName - if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen { - displayName = displayName[:scp.DisplayNameLen] - } else if scp.DisplayNameLen == NoneLength { - displayName = "" - } - - if len(displayName) > 0 { - username = fmt.Sprintf("%s%s%s", username, scp.Separator, displayName) - } - - roleName := config.RoleName - if scp.RoleNameLen > 0 && len(roleName) > scp.RoleNameLen { - roleName = roleName[:scp.RoleNameLen] - } else if scp.RoleNameLen == NoneLength { - roleName = "" - } - - if len(roleName) > 0 { - username = fmt.Sprintf("%s%s%s", username, scp.Separator, roleName) - } - - userUUID, err := RandomAlphaNumeric(20, false) - if err != nil { - return "", err - } - - username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID) - username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix())) - if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { - username = username[:scp.UsernameLen] - } - + caseOp := KeepCase if scp.LowercaseUsername { - username = strings.ToLower(username) + caseOp = Lowercase } - - return username, nil + return GenerateUsername( + DisplayName(config.DisplayName, scp.DisplayNameLen), + RoleName(config.RoleName, scp.RoleNameLen), + Case(caseOp), + Separator(scp.Separator), + MaxLength(scp.UsernameLen), + ) } func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) { diff --git a/sdk/database/helper/credsutil/usernames.go b/sdk/database/helper/credsutil/usernames.go new file mode 100644 index 000000000..c1e3ccb52 --- /dev/null +++ b/sdk/database/helper/credsutil/usernames.go @@ -0,0 +1,140 @@ +package credsutil + +import ( + "fmt" + "strings" + "time" +) + +type CaseOp int + +const ( + KeepCase CaseOp = iota + Uppercase + Lowercase +) + +type usernameBuilder struct { + displayName string + roleName string + separator string + + maxLen int + caseOperation CaseOp +} + +func (ub usernameBuilder) makeUsername() (string, error) { + userUUID, err := RandomAlphaNumeric(20, false) + if err != nil { + return "", err + } + + now := fmt.Sprint(time.Now().Unix()) + + username := joinNonEmpty(ub.separator, + "v", + ub.displayName, + ub.roleName, + userUUID, + now) + username = trunc(username, ub.maxLen) + switch ub.caseOperation { + case Lowercase: + username = strings.ToLower(username) + case Uppercase: + username = strings.ToUpper(username) + } + + return username, nil +} + +type UsernameOpt func(*usernameBuilder) + +func DisplayName(dispName string, maxLength int) UsernameOpt { + return func(b *usernameBuilder) { + b.displayName = trunc(dispName, maxLength) + } +} + +func RoleName(roleName string, maxLength int) UsernameOpt { + return func(b *usernameBuilder) { + b.roleName = trunc(roleName, maxLength) + } +} + +func Separator(sep string) UsernameOpt { + return func(b *usernameBuilder) { + b.separator = sep + } +} + +func MaxLength(maxLen int) UsernameOpt { + return func(b *usernameBuilder) { + b.maxLen = maxLen + } +} + +func Case(c CaseOp) UsernameOpt { + return func(b *usernameBuilder) { + b.caseOperation = c + } +} + +func ToLower() UsernameOpt { + return Case(Lowercase) +} + +func ToUpper() UsernameOpt { + return Case(Uppercase) +} + +func GenerateUsername(opts ...UsernameOpt) (string, error) { + b := usernameBuilder{ + separator: "_", + maxLen: 100, + caseOperation: KeepCase, + } + + for _, opt := range opts { + opt(&b) + } + + return b.makeUsername() +} + +func trunc(str string, l int) string { + switch { + case l > 0: + if l > len(str) { + return str + } + return str[:l] + case l == 0: + return str + default: + return "" + } +} + +func joinNonEmpty(sep string, vals ...string) string { + if sep == "" { + return strings.Join(vals, sep) + } + switch len(vals) { + case 0: + return "" + case 1: + return vals[0] + } + builder := &strings.Builder{} + for _, val := range vals { + if val == "" { + continue + } + if builder.Len() > 0 { + builder.WriteString(sep) + } + builder.WriteString(val) + } + return builder.String() +} diff --git a/sdk/database/helper/credsutil/usernames_test.go b/sdk/database/helper/credsutil/usernames_test.go new file mode 100644 index 000000000..b1e79ce26 --- /dev/null +++ b/sdk/database/helper/credsutil/usernames_test.go @@ -0,0 +1,144 @@ +package credsutil + +import ( + "regexp" + "testing" +) + +func TestGenerateUsername(t *testing.T) { + type testCase struct { + displayName string + displayNameLen int + + roleName string + roleNameLen int + + usernameLen int + separator string + caseOp CaseOp + + regex string + } + tests := map[string]testCase{ + "all opts": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 10, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 10, + usernameLen: 45, + separator: ".", + caseOp: KeepCase, + + regex: "^v.abcdefghij.zyxwvutsrq.[a-zA-Z0-9]{20}.$", + }, + "no separator": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 10, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 10, + usernameLen: 45, + separator: "", + caseOp: KeepCase, + + regex: "^vabcdefghijzyxwvutsrq[a-zA-Z0-9]{20}[0-9]{4}$", + }, + "lowercase": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 10, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 10, + usernameLen: 45, + separator: "_", + caseOp: Lowercase, + + regex: "^v_abcdefghij_zyxwvutsrq_[a-z0-9]{20}_$", + }, + "uppercase": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 10, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 10, + usernameLen: 45, + separator: "_", + caseOp: Uppercase, + + regex: "^V_ABCDEFGHIJ_ZYXWVUTSRQ_[A-Z0-9]{20}_$", + }, + "short username": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 5, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 5, + usernameLen: 15, + separator: "_", + caseOp: KeepCase, + + regex: "^v_abcde_zyxwv_[a-zA-Z0-9]{1}$", + }, + "long username": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 0, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 0, + usernameLen: 100, + separator: "_", + caseOp: KeepCase, + + regex: "^v_abcdefghijklmonpqrstuvwxyz_zyxwvutsrqpnomlkjihgfedcba_[a-zA-Z0-9]{20}_[0-9]{1,23}$", + }, + "zero max length": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 0, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 0, + usernameLen: 0, + separator: "_", + caseOp: KeepCase, + + regex: "^v_abcdefghijklmonpqrstuvwxyz_zyxwvutsrqpnomlkjihgfedcba_[a-zA-Z0-9]{20}_[0-9]+$", + }, + "no display name": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: NoneLength, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: 15, + usernameLen: 100, + separator: "_", + caseOp: KeepCase, + + regex: "^v_zyxwvutsrqpnoml_[a-zA-Z0-9]{20}_[0-9]+$", + }, + "no role name": { + displayName: "abcdefghijklmonpqrstuvwxyz", + displayNameLen: 15, + roleName: "zyxwvutsrqpnomlkjihgfedcba", + roleNameLen: NoneLength, + usernameLen: 100, + separator: "_", + caseOp: KeepCase, + + regex: "^v_abcdefghijklmon_[a-zA-Z0-9]{20}_[0-9]+$", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + re := regexp.MustCompile(test.regex) + + username, err := GenerateUsername( + DisplayName(test.displayName, test.displayNameLen), + RoleName(test.roleName, test.roleNameLen), + Separator(test.separator), + MaxLength(test.usernameLen), + Case(test.caseOp), + ) + if err != nil { + t.Fatalf("no error expected, got: %s", err) + } + + if !re.MatchString(username) { + t.Fatalf("username %q does not match regex %q", username, test.regex) + } + }) + } +}