From 4e7237178fd847a2aabfc68b78f1a4eaf6d87576 Mon Sep 17 00:00:00 2001 From: Becca Petrin Date: Mon, 13 Aug 2018 11:02:44 -0700 Subject: [PATCH] Add a header type field (#4993) --- logical/framework/backend.go | 3 + logical/framework/backend_test.go | 7 ++ logical/framework/field_data.go | 104 ++++++++++++++++++++++++++- logical/framework/field_data_test.go | 93 +++++++++++++++++++++++- logical/framework/field_type.go | 8 +++ 5 files changed, 212 insertions(+), 3 deletions(-) diff --git a/logical/framework/backend.go b/logical/framework/backend.go index 3c674e365..4c787cbed 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "net/http" "regexp" "sort" "strings" @@ -548,6 +549,8 @@ func (t FieldType) Zero() interface{} { return []string{} case TypeCommaIntSlice: return []int{} + case TypeHeader: + return http.Header{} default: panic("unknown type: " + t.String()) } diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index fa050ac60..09c6f9045 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "net/http" + "github.com/hashicorp/vault/logical" ) @@ -532,6 +534,11 @@ func TestFieldSchemaDefaultOrZero(t *testing.T) { &FieldSchema{Type: TypeDurationSecond}, 0, }, + + "default header not set": { + &FieldSchema{Type: TypeHeader}, + http.Header{}, + }, } for name, tc := range cases { diff --git a/logical/framework/field_data.go b/logical/framework/field_data.go index 2bbb34805..2ee529116 100644 --- a/logical/framework/field_data.go +++ b/logical/framework/field_data.go @@ -1,9 +1,12 @@ package framework import ( + "bytes" + "encoding/base64" "encoding/json" "errors" "fmt" + "net/http" "regexp" "strings" @@ -37,7 +40,7 @@ func (d *FieldData) Validate() error { switch schema.Type { case TypeBool, TypeInt, TypeMap, TypeDurationSecond, TypeString, TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice, - TypeKVPairs, TypeCommaIntSlice: + TypeKVPairs, TypeCommaIntSlice, TypeHeader: _, _, err := d.getPrimitive(field, schema) if err != nil { return errwrap.Wrapf(fmt.Sprintf("error converting input %v for field %q: {{err}}", value, field), err) @@ -126,7 +129,7 @@ func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) { switch schema.Type { case TypeBool, TypeInt, TypeMap, TypeDurationSecond, TypeString, TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice, - TypeKVPairs, TypeCommaIntSlice: + TypeKVPairs, TypeCommaIntSlice, TypeHeader: return d.getPrimitive(k, schema) default: return nil, false, @@ -291,6 +294,103 @@ func (d *FieldData) getPrimitive(k string, schema *FieldSchema) (interface{}, bo } return result, true, nil + case TypeHeader: + /* + + There are multiple ways a header could be provided: + + 1. As a map[string]interface{} that resolves to a map[string]string or map[string][]string, or a mix of both + because that's permitted for headers. + This mainly comes from the API. + + 2. As a string... + a. That contains JSON that originally was JSON, but then was base64 encoded. + b. That contains JSON, ex. `{"content-type":"text/json","accept":["encoding/json"]}`. + This mainly comes from the API and is used to save space while sending in the header. + + 3. As an array of strings that contains comma-delimited key-value pairs associated via a colon, + ex: `content-type:text/json`,`accept:encoding/json`. + This mainly comes from the CLI. + + We go through these sequentially below. + + */ + result := http.Header{} + + toHeader := func(resultMap map[string]interface{}) (http.Header, error) { + header := http.Header{} + for headerKey, headerValGroup := range resultMap { + switch typedHeader := headerValGroup.(type) { + case string: + header.Add(headerKey, typedHeader) + case []string: + for _, headerVal := range typedHeader { + header.Add(headerKey, headerVal) + } + case []interface{}: + for _, headerVal := range typedHeader { + strHeaderVal, ok := headerVal.(string) + if !ok { + // All header values should already be strings when they're being sent in. + // Even numbers and booleans will be treated as strings. + return nil, fmt.Errorf("received non-string value for header key:%s, val:%s", headerKey, headerValGroup) + } + header.Add(headerKey, strHeaderVal) + } + default: + return nil, fmt.Errorf("unrecognized type for %s", headerValGroup) + } + } + return header, nil + } + + resultMap := make(map[string]interface{}) + + // 1. Are we getting a map from the API? + if err := mapstructure.WeakDecode(raw, &resultMap); err == nil { + result, err = toHeader(resultMap) + if err != nil { + return nil, true, err + } + return result, true, nil + } + + // 2. Are we getting a JSON string? + if headerStr, ok := raw.(string); ok { + // a. Is it base64 encoded? + headerBytes, err := base64.StdEncoding.DecodeString(headerStr) + if err != nil { + // b. It's not base64 encoded, it's a straight-out JSON string. + headerBytes = []byte(headerStr) + } + if err := json.NewDecoder(bytes.NewReader(headerBytes)).Decode(&resultMap); err != nil { + return nil, true, err + } + result, err = toHeader(resultMap) + if err != nil { + return nil, true, err + } + return result, true, nil + } + + // 3. Are we getting an array of fields like "content-type:encoding/json" from the CLI? + var keyPairs []interface{} + if err := mapstructure.WeakDecode(raw, &keyPairs); err == nil { + for _, keyPairIfc := range keyPairs { + keyPair, ok := keyPairIfc.(string) + if !ok { + return nil, true, fmt.Errorf("invalid key pair %q", keyPair) + } + keyPairSlice := strings.SplitN(keyPair, ":", 2) + if len(keyPairSlice) != 2 || keyPairSlice[0] == "" { + return nil, true, fmt.Errorf("invalid key pair %q", keyPair) + } + result.Add(keyPairSlice[0], keyPairSlice[1]) + } + return result, true, nil + } + return nil, true, fmt.Errorf("%s not provided an expected format", raw) + default: panic(fmt.Sprintf("Unknown type: %s", schema.Type)) } diff --git a/logical/framework/field_data_test.go b/logical/framework/field_data_test.go index 86ebdd13d..f4763af2a 100644 --- a/logical/framework/field_data_test.go +++ b/logical/framework/field_data_test.go @@ -1,6 +1,7 @@ package framework import ( + "net/http" "reflect" "testing" ) @@ -467,6 +468,87 @@ func TestFieldDataGet(t *testing.T) { }, }, + "type header, keypair string array": { + map[string]*FieldSchema{ + "foo": {Type: TypeHeader}, + }, + map[string]interface{}{ + "foo": []interface{}{"key1:value1", "key2:value2", "key3:1"}, + }, + "foo", + http.Header{ + "Key1": []string{"value1"}, + "Key2": []string{"value2"}, + "Key3": []string{"1"}, + }, + }, + + "type header, b64 string": { + map[string]*FieldSchema{ + "foo": {Type: TypeHeader}, + }, + map[string]interface{}{ + "foo": "eyJDb250ZW50LUxlbmd0aCI6IFsiNDMiXSwgIlVzZXItQWdlbnQiOiBbImF3cy1zZGstZ28vMS40LjEyIChnbzEuNy4xOyBsaW51eDsgYW1kNjQpIl0sICJYLVZhdWx0LUFXU0lBTS1TZXJ2ZXItSWQiOiBbInZhdWx0LmV4YW1wbGUuY29tIl0sICJYLUFtei1EYXRlIjogWyIyMDE2MDkzMFQwNDMxMjFaIl0sICJDb250ZW50LVR5cGUiOiBbImFwcGxpY2F0aW9uL3gtd3d3LWZvcm0tdXJsZW5jb2RlZDsgY2hhcnNldD11dGYtOCJdLCAiQXV0aG9yaXphdGlvbiI6IFsiQVdTNC1ITUFDLVNIQTI1NiBDcmVkZW50aWFsPWZvby8yMDE2MDkzMC91cy1lYXN0LTEvc3RzL2F3czRfcmVxdWVzdCwgU2lnbmVkSGVhZGVycz1jb250ZW50LWxlbmd0aDtjb250ZW50LXR5cGU7aG9zdDt4LWFtei1kYXRlO3gtdmF1bHQtc2VydmVyLCBTaWduYXR1cmU9YTY5ZmQ3NTBhMzQ0NWM0ZTU1M2UxYjNlNzlkM2RhOTBlZWY1NDA0N2YxZWI0ZWZlOGZmYmM5YzQyOGMyNjU1YiJdfQ==", + }, + "foo", + http.Header{ + "Content-Length": []string{"43"}, + "User-Agent": []string{"aws-sdk-go/1.4.12 (go1.7.1; linux; amd64)"}, + "X-Vault-Awsiam-Server-Id": []string{"vault.example.com"}, + "X-Amz-Date": []string{"20160930T043121Z"}, + "Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"}, + "Authorization": []string{"AWS4-HMAC-SHA256 Credential=foo/20160930/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-vault-server, Signature=a69fd750a3445c4e553e1b3e79d3da90eef54047f1eb4efe8ffbc9c428c2655b"}, + }, + }, + + "type header, json string": { + map[string]*FieldSchema{ + "foo": {Type: TypeHeader}, + }, + map[string]interface{}{ + "foo": `{"hello":"world","bonjour":["monde","dieu"]}`, + }, + "foo", + http.Header{ + "Hello": []string{"world"}, + "Bonjour": []string{"monde", "dieu"}, + }, + }, + + "type header, keypair string array with dupe key": { + map[string]*FieldSchema{ + "foo": {Type: TypeHeader}, + }, + map[string]interface{}{ + "foo": []interface{}{"key1:value1", "key2:value2", "key3:1", "key3:true"}, + }, + "foo", + http.Header{ + "Key1": []string{"value1"}, + "Key2": []string{"value2"}, + "Key3": []string{"1", "true"}, + }, + }, + + "type header, map string slice": { + map[string]*FieldSchema{ + "foo": {Type: TypeHeader}, + }, + map[string]interface{}{ + "foo": map[string][]string{ + "key1": {"value1"}, + "key2": {"value2"}, + "key3": {"1"}, + }, + }, + "foo", + http.Header{ + "Key1": []string{"value1"}, + "Key2": []string{"value2"}, + "Key3": []string{"1"}, + }, + }, + "name string type, not supplied": { map[string]*FieldSchema{ "foo": {Type: TypeNameString}, @@ -556,6 +638,15 @@ func TestFieldDataGet(t *testing.T) { "foo", map[string]string{}, }, + + "type header, not supplied": { + map[string]*FieldSchema{ + "foo": {Type: TypeHeader}, + }, + map[string]interface{}{}, + "foo", + http.Header{}, + }, } for name, tc := range cases { @@ -565,7 +656,7 @@ func TestFieldDataGet(t *testing.T) { } if err := data.Validate(); err != nil { - t.Fatalf("bad: %#v", err) + t.Fatalf("bad: %s", err) } actual := data.Get(tc.Key) diff --git a/logical/framework/field_type.go b/logical/framework/field_type.go index d447eabfb..64a6a56dc 100644 --- a/logical/framework/field_type.go +++ b/logical/framework/field_type.go @@ -42,6 +42,12 @@ const ( // TypeCommaIntSlice is a helper for TypeSlice that returns a sanitized // slice of Ints TypeCommaIntSlice + + // TypeHeader is a helper for sending request headers through to Vault. + // For instance, the AWS and AliCloud credential plugins both act as a + // benevolent MITM for a request, and the headers are sent through and + // parsed. + TypeHeader ) func (t FieldType) String() string { @@ -64,6 +70,8 @@ func (t FieldType) String() string { return "duration (sec)" case TypeSlice, TypeStringSlice, TypeCommaStringSlice, TypeCommaIntSlice: return "slice" + case TypeHeader: + return "header" default: return "unknown type" }