Add arbitrary string slice parsing.

Like the KV function, this supports either separated strings or JSON
strings, base64-encoded or not.

Fixes #1619 in theory.
This commit is contained in:
Jeff Mitchell 2016-08-03 14:18:22 -04:00
parent c025b292b5
commit 9e204bd88c
14 changed files with 146 additions and 66 deletions

View file

@ -1452,7 +1452,7 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
Metadata: make(map[string]string),
}
if err = strutil.ParseArbitraryKeyValues(data.Get("metadata").(string), secretIDStorage.Metadata); err != nil {
if err = strutil.ParseArbitraryKeyValues(data.Get("metadata").(string), secretIDStorage.Metadata, ","); err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to parse metadata: %v", err)), nil
}

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -61,13 +62,23 @@ func (b *backend) pathCredsCreateRead(
}
// Execute each query
for _, query := range splitSQL(role.CreationCQL) {
for _, query := range strutil.ParseArbitraryStringSlice(role.CreationCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
err = session.Query(substQuery(query, map[string]string{
"username": username,
"password": password,
})).Exec()
if err != nil {
for _, query := range splitSQL(role.RollbackCQL) {
for _, query := range strutil.ParseArbitraryStringSlice(role.RollbackCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
session.Query(substQuery(query, map[string]string{
"username": username,
"password": password,

View file

@ -12,19 +12,6 @@ import (
"github.com/hashicorp/vault/logical"
)
// SplitSQL is used to split a series of SQL statements
func splitSQL(sql string) []string {
parts := strings.Split(sql, ";")
out := make([]string, 0, len(parts))
for _, p := range parts {
clean := strings.TrimSpace(p)
if len(clean) > 0 {
out = append(out, clean)
}
}
return out
}
// Query templates a query for us.
func substQuery(tpl string, data map[string]string) string {
for k, v := range data {

View file

@ -2,8 +2,10 @@ package mssql
import (
"fmt"
"strings"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -82,7 +84,12 @@ func (b *backend) pathCredsCreateRead(
roleSQL := fmt.Sprintf("USE [%s]; %s", b.defaultDb, role.SQL)
// Execute each query
for _, query := range SplitSQL(roleSQL) {
for _, query := range strutil.ParseArbitraryStringSlice(roleSQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(Query(query, map[string]string{
"name": username,
"password": password,

View file

@ -2,7 +2,9 @@ package mssql
import (
"fmt"
"strings"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -112,7 +114,12 @@ func (b *backend) pathRoleCreate(
}
// Test the query by trying to prepare it
for _, query := range SplitSQL(sql) {
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := db.Prepare(Query(query, map[string]string{
"name": "foo",
"password": "bar",

View file

@ -2,8 +2,10 @@ package mysql
import (
"fmt"
"strings"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
@ -95,7 +97,12 @@ func (b *backend) pathRoleCreateRead(
defer tx.Rollback()
// Execute each query
for _, query := range SplitSQL(role.SQL) {
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(Query(query, map[string]string{
"name": username,
"password": password,

View file

@ -2,8 +2,10 @@ package mysql
import (
"fmt"
"strings"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -140,7 +142,12 @@ func (b *backend) pathRoleCreate(
}
// Test the query by trying to prepare it
for _, query := range SplitSQL(sql) {
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := db.Prepare(Query(query, map[string]string{
"name": "foo",
"password": "bar",

View file

@ -5,19 +5,6 @@ import (
"strings"
)
// SplitSQL is used to split a series of SQL statements
func SplitSQL(sql string) []string {
parts := strings.Split(sql, ";")
out := make([]string, 0, len(parts))
for _, p := range parts {
clean := strings.TrimSpace(p)
if len(clean) > 0 {
out = append(out, clean)
}
}
return out
}
// Query templates a query for us.
func Query(tpl string, data map[string]string) string {
for k, v := range data {

View file

@ -2,9 +2,11 @@ package postgresql
import (
"fmt"
"strings"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
@ -100,7 +102,12 @@ func (b *backend) pathRoleCreateRead(
}()
// Execute each query
for _, query := range SplitSQL(role.SQL) {
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
b.logger.Println("[TRACE] postgres/pathRoleCreateRead: preparing statement")
stmt, err := tx.Prepare(Query(query, map[string]string{
"name": username,

View file

@ -2,7 +2,9 @@ package postgresql
import (
"fmt"
"strings"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -112,7 +114,12 @@ func (b *backend) pathRoleCreate(
}
// Test the query by trying to prepare it
for _, query := range SplitSQL(sql) {
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := db.Prepare(Query(query, map[string]string{
"name": "foo",
"password": "bar",

View file

@ -1,16 +0,0 @@
package postgresql
import "strings"
// SplitSQL is used to split a series of SQL statements
func SplitSQL(sql string) []string {
parts := strings.Split(sql, ";")
out := make([]string, 0, len(parts))
for _, p := range parts {
clean := strings.TrimSpace(p)
if len(clean) > 0 {
out = append(out, clean)
}
}
return out
}

View file

@ -32,20 +32,24 @@ func StrListSubset(super, sub []string) bool {
// Parses a comma separated list of strings into a slice of strings.
// The return slice will be sorted and will not contain duplicate or
// empty items. The values will be converted to lower case.
func ParseDedupAndSortStrings(input string) []string {
func ParseDedupAndSortStrings(input string, sep string) []string {
input = strings.TrimSpace(input)
var parsed []string
if input == "" {
// Don't return nil
return parsed
}
return RemoveDuplicates(strings.Split(input, ","))
return RemoveDuplicates(strings.Split(input, sep))
}
// Parses a comma separated list of `<key>=<value>` tuples into a
// map[string]string.
func ParseKeyValues(input string, out map[string]string) error {
keyValues := ParseDedupAndSortStrings(input)
func ParseKeyValues(input string, out map[string]string, sep string) error {
if out == nil {
return fmt.Errorf("'out is nil")
}
keyValues := ParseDedupAndSortStrings(input, sep)
if len(keyValues) == 0 {
return nil
}
@ -72,7 +76,7 @@ func ParseKeyValues(input string, out map[string]string) error {
//
// Input will be parsed into the output paramater, which should
// be a non-nil map[string]string.
func ParseArbitraryKeyValues(input string, out map[string]string) error {
func ParseArbitraryKeyValues(input string, out map[string]string, sep string) error {
input = strings.TrimSpace(input)
if input == "" {
return nil
@ -94,7 +98,7 @@ func ParseArbitraryKeyValues(input string, out map[string]string) error {
if err != nil {
// If JSON unmarshalling fails, consider that the input was
// supplied as a comma separated string of 'key=value' pairs.
if err = ParseKeyValues(input, out); err != nil {
if err = ParseKeyValues(input, out, sep); err != nil {
return fmt.Errorf("failed to parse the input: %v", err)
}
}
@ -109,6 +113,71 @@ func ParseArbitraryKeyValues(input string, out map[string]string) error {
return nil
}
// Parses a `sep`-separated list of strings into a
// []string.
//
// The output will always be a valid slice but may be of length zero.
func ParseStringSlice(input string, sep string) []string {
input = strings.TrimSpace(input)
if input == "" {
return []string{}
}
splitStr := strings.Split(input, sep)
ret := make([]string, len(splitStr))
for i, val := range splitStr {
ret[i] = val
}
return ret
}
// Parses arbitrary string slice. The input can be one of
// the following:
// * JSON string
// * Base64 encoded JSON string
// * `sep` separated list of values
// * Base64-encoded string containting a `sep` separated list of values
//
// Note that the separator is ignored if the input is found to already be in a
// structured format (e.g., JSON)
//
// The output will always be a valid slice but may be of length zero.
func ParseArbitraryStringSlice(input string, sep string) []string {
input = strings.TrimSpace(input)
if input == "" {
return []string{}
}
// Try to base64 decode the input. If successful, consider the decoded
// value as input.
inputBytes, err := base64.StdEncoding.DecodeString(input)
if err == nil {
input = string(inputBytes)
}
var d struct {
Ret []string
}
var outD d
// Try to JSON unmarshal the input. If successful, consider that the
// metadata was supplied as JSON input.
err = json.Unmarshal([]byte(input), &outD)
if err != nil {
// If JSON unmarshalling fails, consider that the input was
// supplied as a separated string of values.
return ParseStringSlice(input, sep)
}
if outD.Ret == nil {
return []string{}
}
return outD.Ret
}
// Removes duplicate and empty elements from a slice of strings.
// This also converts the items in the slice to lower case and
// returns a sorted slice.

View file

@ -75,7 +75,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
var err error
input = "key1=value1,key2=value2"
err = ParseKeyValues(input, actual)
err = ParseKeyValues(input, actual, ",")
if err != nil {
t.Fatal(err)
}
@ -87,7 +87,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
}
input = "key1 = value1, key2 = value2"
err = ParseKeyValues(input, actual)
err = ParseKeyValues(input, actual, ",")
if err != nil {
t.Fatal(err)
}
@ -99,7 +99,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
}
input = "key1 = value1, key2 = "
err = ParseKeyValues(input, actual)
err = ParseKeyValues(input, actual, ",")
if err == nil {
t.Fatal("expected an error")
}
@ -108,7 +108,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
}
input = "key1 = value1, = value2 "
err = ParseKeyValues(input, actual)
err = ParseKeyValues(input, actual, ",")
if err == nil {
t.Fatal("expected an error")
}
@ -128,7 +128,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
// Test <key>=<value> as comma separated string
input = "key1=value1,key2=value2"
err = ParseArbitraryKeyValues(input, actual)
err = ParseArbitraryKeyValues(input, actual, ",")
if err != nil {
t.Fatal(err)
}
@ -141,7 +141,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
// Test <key>=<value> as base64 encoded comma separated string
input = base64.StdEncoding.EncodeToString([]byte(input))
err = ParseArbitraryKeyValues(input, actual)
err = ParseArbitraryKeyValues(input, actual, ",")
if err != nil {
t.Fatal(err)
}
@ -154,7 +154,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
// Test JSON encoded <key>=<value> tuples
input = `{"key1":"value1", "key2":"value2"}`
err = ParseArbitraryKeyValues(input, actual)
err = ParseArbitraryKeyValues(input, actual, ",")
if err != nil {
t.Fatal(err)
}
@ -167,7 +167,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
// Test base64 encoded JSON string of <key>=<value> tuples
input = base64.StdEncoding.EncodeToString([]byte(input))
err = ParseArbitraryKeyValues(input, actual)
err = ParseArbitraryKeyValues(input, actual, ",")
if err != nil {
t.Fatal(err)
}

View file

@ -184,7 +184,7 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro
kv: client.KV(),
permitPool: NewPermitPool(maxParInt),
serviceName: service,
serviceTags: strutil.ParseDedupAndSortStrings(tags),
serviceTags: strutil.ParseDedupAndSortStrings(tags, ","),
checkTimeout: checkTimeout,
disableRegistration: disableRegistration,
}