Add arbitrary string slice parsing.
Like the KV function, this supports either separated strings or JSON strings, base64-encoded or not. Fixes #1619 in theory.
This commit is contained in:
parent
c025b292b5
commit
9e204bd88c
|
@ -1452,7 +1452,7 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
|
|||
Metadata: make(map[string]string),
|
||||
}
|
||||
|
||||
if err = strutil.ParseArbitraryKeyValues(data.Get("metadata").(string), secretIDStorage.Metadata); err != nil {
|
||||
if err = strutil.ParseArbitraryKeyValues(data.Get("metadata").(string), secretIDStorage.Metadata, ","); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to parse metadata: %v", err)), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -61,13 +62,23 @@ func (b *backend) pathCredsCreateRead(
|
|||
}
|
||||
|
||||
// Execute each query
|
||||
for _, query := range splitSQL(role.CreationCQL) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.CreationCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = session.Query(substQuery(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
if err != nil {
|
||||
for _, query := range splitSQL(role.RollbackCQL) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.RollbackCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
session.Query(substQuery(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
|
|
|
@ -12,19 +12,6 @@ import (
|
|||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
// SplitSQL is used to split a series of SQL statements
|
||||
func splitSQL(sql string) []string {
|
||||
parts := strings.Split(sql, ";")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
clean := strings.TrimSpace(p)
|
||||
if len(clean) > 0 {
|
||||
out = append(out, clean)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Query templates a query for us.
|
||||
func substQuery(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
|
|
|
@ -2,8 +2,10 @@ package mssql
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -82,7 +84,12 @@ func (b *backend) pathCredsCreateRead(
|
|||
roleSQL := fmt.Sprintf("USE [%s]; %s", b.defaultDb, role.SQL)
|
||||
|
||||
// Execute each query
|
||||
for _, query := range SplitSQL(roleSQL) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(roleSQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(Query(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
|
|
|
@ -2,7 +2,9 @@ package mssql
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -112,7 +114,12 @@ func (b *backend) pathRoleCreate(
|
|||
}
|
||||
|
||||
// Test the query by trying to prepare it
|
||||
for _, query := range SplitSQL(sql) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := db.Prepare(Query(query, map[string]string{
|
||||
"name": "foo",
|
||||
"password": "bar",
|
||||
|
|
|
@ -2,8 +2,10 @@ package mysql
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
_ "github.com/lib/pq"
|
||||
|
@ -95,7 +97,12 @@ func (b *backend) pathRoleCreateRead(
|
|||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range SplitSQL(role.SQL) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(Query(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
|
|
|
@ -2,8 +2,10 @@ package mysql
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -140,7 +142,12 @@ func (b *backend) pathRoleCreate(
|
|||
}
|
||||
|
||||
// Test the query by trying to prepare it
|
||||
for _, query := range SplitSQL(sql) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := db.Prepare(Query(query, map[string]string{
|
||||
"name": "foo",
|
||||
"password": "bar",
|
||||
|
|
|
@ -5,19 +5,6 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// SplitSQL is used to split a series of SQL statements
|
||||
func SplitSQL(sql string) []string {
|
||||
parts := strings.Split(sql, ";")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
clean := strings.TrimSpace(p)
|
||||
if len(clean) > 0 {
|
||||
out = append(out, clean)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Query templates a query for us.
|
||||
func Query(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
|
|
|
@ -2,9 +2,11 @@ package postgresql
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
_ "github.com/lib/pq"
|
||||
|
@ -100,7 +102,12 @@ func (b *backend) pathRoleCreateRead(
|
|||
}()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range SplitSQL(role.SQL) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
b.logger.Println("[TRACE] postgres/pathRoleCreateRead: preparing statement")
|
||||
stmt, err := tx.Prepare(Query(query, map[string]string{
|
||||
"name": username,
|
||||
|
|
|
@ -2,7 +2,9 @@ package postgresql
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -112,7 +114,12 @@ func (b *backend) pathRoleCreate(
|
|||
}
|
||||
|
||||
// Test the query by trying to prepare it
|
||||
for _, query := range SplitSQL(sql) {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := db.Prepare(Query(query, map[string]string{
|
||||
"name": "foo",
|
||||
"password": "bar",
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import "strings"
|
||||
|
||||
// SplitSQL is used to split a series of SQL statements
|
||||
func SplitSQL(sql string) []string {
|
||||
parts := strings.Split(sql, ";")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
clean := strings.TrimSpace(p)
|
||||
if len(clean) > 0 {
|
||||
out = append(out, clean)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
|
@ -32,20 +32,24 @@ func StrListSubset(super, sub []string) bool {
|
|||
// Parses a comma separated list of strings into a slice of strings.
|
||||
// The return slice will be sorted and will not contain duplicate or
|
||||
// empty items. The values will be converted to lower case.
|
||||
func ParseDedupAndSortStrings(input string) []string {
|
||||
func ParseDedupAndSortStrings(input string, sep string) []string {
|
||||
input = strings.TrimSpace(input)
|
||||
var parsed []string
|
||||
if input == "" {
|
||||
// Don't return nil
|
||||
return parsed
|
||||
}
|
||||
return RemoveDuplicates(strings.Split(input, ","))
|
||||
return RemoveDuplicates(strings.Split(input, sep))
|
||||
}
|
||||
|
||||
// Parses a comma separated list of `<key>=<value>` tuples into a
|
||||
// map[string]string.
|
||||
func ParseKeyValues(input string, out map[string]string) error {
|
||||
keyValues := ParseDedupAndSortStrings(input)
|
||||
func ParseKeyValues(input string, out map[string]string, sep string) error {
|
||||
if out == nil {
|
||||
return fmt.Errorf("'out is nil")
|
||||
}
|
||||
|
||||
keyValues := ParseDedupAndSortStrings(input, sep)
|
||||
if len(keyValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -72,7 +76,7 @@ func ParseKeyValues(input string, out map[string]string) error {
|
|||
//
|
||||
// Input will be parsed into the output paramater, which should
|
||||
// be a non-nil map[string]string.
|
||||
func ParseArbitraryKeyValues(input string, out map[string]string) error {
|
||||
func ParseArbitraryKeyValues(input string, out map[string]string, sep string) error {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return nil
|
||||
|
@ -94,7 +98,7 @@ func ParseArbitraryKeyValues(input string, out map[string]string) error {
|
|||
if err != nil {
|
||||
// If JSON unmarshalling fails, consider that the input was
|
||||
// supplied as a comma separated string of 'key=value' pairs.
|
||||
if err = ParseKeyValues(input, out); err != nil {
|
||||
if err = ParseKeyValues(input, out, sep); err != nil {
|
||||
return fmt.Errorf("failed to parse the input: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -109,6 +113,71 @@ func ParseArbitraryKeyValues(input string, out map[string]string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Parses a `sep`-separated list of strings into a
|
||||
// []string.
|
||||
//
|
||||
// The output will always be a valid slice but may be of length zero.
|
||||
func ParseStringSlice(input string, sep string) []string {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
splitStr := strings.Split(input, sep)
|
||||
ret := make([]string, len(splitStr))
|
||||
for i, val := range splitStr {
|
||||
ret[i] = val
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Parses arbitrary string slice. The input can be one of
|
||||
// the following:
|
||||
// * JSON string
|
||||
// * Base64 encoded JSON string
|
||||
// * `sep` separated list of values
|
||||
// * Base64-encoded string containting a `sep` separated list of values
|
||||
//
|
||||
// Note that the separator is ignored if the input is found to already be in a
|
||||
// structured format (e.g., JSON)
|
||||
//
|
||||
// The output will always be a valid slice but may be of length zero.
|
||||
func ParseArbitraryStringSlice(input string, sep string) []string {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Try to base64 decode the input. If successful, consider the decoded
|
||||
// value as input.
|
||||
inputBytes, err := base64.StdEncoding.DecodeString(input)
|
||||
if err == nil {
|
||||
input = string(inputBytes)
|
||||
}
|
||||
|
||||
var d struct {
|
||||
Ret []string
|
||||
}
|
||||
|
||||
var outD d
|
||||
|
||||
// Try to JSON unmarshal the input. If successful, consider that the
|
||||
// metadata was supplied as JSON input.
|
||||
err = json.Unmarshal([]byte(input), &outD)
|
||||
if err != nil {
|
||||
// If JSON unmarshalling fails, consider that the input was
|
||||
// supplied as a separated string of values.
|
||||
return ParseStringSlice(input, sep)
|
||||
}
|
||||
|
||||
if outD.Ret == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
return outD.Ret
|
||||
}
|
||||
|
||||
// Removes duplicate and empty elements from a slice of strings.
|
||||
// This also converts the items in the slice to lower case and
|
||||
// returns a sorted slice.
|
||||
|
|
|
@ -75,7 +75,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
|
|||
var err error
|
||||
|
||||
input = "key1=value1,key2=value2"
|
||||
err = ParseKeyValues(input, actual)
|
||||
err = ParseKeyValues(input, actual, ",")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
|
|||
}
|
||||
|
||||
input = "key1 = value1, key2 = value2"
|
||||
err = ParseKeyValues(input, actual)
|
||||
err = ParseKeyValues(input, actual, ",")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
|
|||
}
|
||||
|
||||
input = "key1 = value1, key2 = "
|
||||
err = ParseKeyValues(input, actual)
|
||||
err = ParseKeyValues(input, actual, ",")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func TestStrutil_ParseKeyValues(t *testing.T) {
|
|||
}
|
||||
|
||||
input = "key1 = value1, = value2 "
|
||||
err = ParseKeyValues(input, actual)
|
||||
err = ParseKeyValues(input, actual, ",")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
|
|||
|
||||
// Test <key>=<value> as comma separated string
|
||||
input = "key1=value1,key2=value2"
|
||||
err = ParseArbitraryKeyValues(input, actual)
|
||||
err = ParseArbitraryKeyValues(input, actual, ",")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
|
|||
|
||||
// Test <key>=<value> as base64 encoded comma separated string
|
||||
input = base64.StdEncoding.EncodeToString([]byte(input))
|
||||
err = ParseArbitraryKeyValues(input, actual)
|
||||
err = ParseArbitraryKeyValues(input, actual, ",")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
|
|||
|
||||
// Test JSON encoded <key>=<value> tuples
|
||||
input = `{"key1":"value1", "key2":"value2"}`
|
||||
err = ParseArbitraryKeyValues(input, actual)
|
||||
err = ParseArbitraryKeyValues(input, actual, ",")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func TestStrutil_ParseArbitraryKeyValues(t *testing.T) {
|
|||
|
||||
// Test base64 encoded JSON string of <key>=<value> tuples
|
||||
input = base64.StdEncoding.EncodeToString([]byte(input))
|
||||
err = ParseArbitraryKeyValues(input, actual)
|
||||
err = ParseArbitraryKeyValues(input, actual, ",")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -184,7 +184,7 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro
|
|||
kv: client.KV(),
|
||||
permitPool: NewPermitPool(maxParInt),
|
||||
serviceName: service,
|
||||
serviceTags: strutil.ParseDedupAndSortStrings(tags),
|
||||
serviceTags: strutil.ParseDedupAndSortStrings(tags, ","),
|
||||
checkTimeout: checkTimeout,
|
||||
disableRegistration: disableRegistration,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue