diff --git a/agent/config.go b/agent/config.go deleted file mode 100644 index 8cb088f03..000000000 --- a/agent/config.go +++ /dev/null @@ -1,100 +0,0 @@ -package agent - -import ( - "errors" - "fmt" - "strings" - "time" - - "github.com/hashicorp/consul/lib" -) - -var errInvalidHeaderFormat = errors.New("agent: invalid format of 'header' field") - -func FixupCheckType(raw interface{}) error { - rawMap, ok := raw.(map[string]interface{}) - if !ok { - return nil - } - - // See https://github.com/hashicorp/consul/pull/3557 why we need this - // and why we should get rid of it. In Consul 1.0 we also didn't map - // Args correctly, so we ended up exposing (and need to carry forward) - // ScriptArgs, see https://github.com/hashicorp/consul/issues/3587. - lib.TranslateKeys(rawMap, map[string]string{ - "args": "ScriptArgs", - "script_args": "ScriptArgs", - "deregister_critical_service_after": "DeregisterCriticalServiceAfter", - "docker_container_id": "DockerContainerID", - "tls_skip_verify": "TLSSkipVerify", - "service_id": "ServiceID", - }) - - parseDuration := func(v interface{}) (time.Duration, error) { - if v == nil { - return 0, nil - } - switch x := v.(type) { - case time.Duration: - return x, nil - case float64: - return time.Duration(x), nil - case string: - return time.ParseDuration(x) - default: - return 0, fmt.Errorf("invalid format") - } - } - - parseHeaderMap := func(v interface{}) (map[string][]string, error) { - if v == nil { - return nil, nil - } - vm, ok := v.(map[string]interface{}) - if !ok { - return nil, errInvalidHeaderFormat - } - m := map[string][]string{} - for k, vv := range vm { - vs, ok := vv.([]interface{}) - if !ok { - return nil, errInvalidHeaderFormat - } - for _, vs := range vs { - s, ok := vs.(string) - if !ok { - return nil, errInvalidHeaderFormat - } - m[k] = append(m[k], s) - } - } - return m, nil - } - - for k, v := range rawMap { - switch strings.ToLower(k) { - case "header": - h, err := parseHeaderMap(v) - if err != nil { - return fmt.Errorf("invalid %q: %s", k, err) - } - rawMap[k] = h - - case "ttl", "interval", "timeout", "deregistercriticalserviceafter": - d, err := parseDuration(v) - if err != nil { - return fmt.Errorf("invalid %q: %v", k, err) - } - rawMap[k] = d - } - } - return nil -} - -func ParseMetaPair(raw string) (string, string) { - pair := strings.SplitN(raw, ":", 2) - if len(pair) == 2 { - return pair[0], pair[1] - } - return pair[0], "" -} diff --git a/agent/discovery_chain_endpoint.go b/agent/discovery_chain_endpoint.go index 1df2e39f4..c6dddd64e 100644 --- a/agent/discovery_chain_endpoint.go +++ b/agent/discovery_chain_endpoint.go @@ -95,63 +95,6 @@ type discoveryChainReadRequest struct { OverrideConnectTimeout time.Duration } -func (t *discoveryChainReadRequest) UnmarshalJSON(data []byte) (err error) { - type Alias discoveryChainReadRequest - aux := &struct { - OverrideConnectTimeout interface{} - OverrideProtocol interface{} - OverrideMeshGateway *struct{ Mode interface{} } - - OverrideConnectTimeoutSnake interface{} `json:"override_connect_timeout"` - OverrideProtocolSnake interface{} `json:"override_protocol"` - OverrideMeshGatewaySnake *struct{ Mode interface{} } `json:"override_mesh_gateway"` - - *Alias - }{ - Alias: (*Alias)(t), - } - if err = lib.UnmarshalJSON(data, &aux); err != nil { - return err - } - - if aux.OverrideConnectTimeout == nil { - aux.OverrideConnectTimeout = aux.OverrideConnectTimeoutSnake - } - if aux.OverrideProtocol == nil { - aux.OverrideProtocol = aux.OverrideProtocolSnake - } - if aux.OverrideMeshGateway == nil { - aux.OverrideMeshGateway = aux.OverrideMeshGatewaySnake - } - - // weakly typed input - if aux.OverrideProtocol != nil { - switch v := aux.OverrideProtocol.(type) { - case string, float64, bool: - t.OverrideProtocol = fmt.Sprintf("%v", v) - default: - return fmt.Errorf("OverrideProtocol: invalid type %T", v) - } - } - if aux.OverrideMeshGateway != nil { - t.OverrideMeshGateway.Mode = structs.MeshGatewayMode(fmt.Sprintf("%v", aux.OverrideMeshGateway.Mode)) - } - - // duration - if aux.OverrideConnectTimeout != nil { - switch v := aux.OverrideConnectTimeout.(type) { - case string: - if t.OverrideConnectTimeout, err = time.ParseDuration(v); err != nil { - return err - } - case float64: - t.OverrideConnectTimeout = time.Duration(v) - } - } - - return nil -} - // discoveryChainReadResponse is the API variation of structs.DiscoveryChainResponse type discoveryChainReadResponse struct { Chain *structs.CompiledDiscoveryChain diff --git a/agent/http.go b/agent/http.go index 7645fc849..b4c6b64ac 100644 --- a/agent/http.go +++ b/agent/http.go @@ -1038,7 +1038,7 @@ func (s *HTTPServer) parseMetaFilter(req *http.Request) map[string]string { if filterList, ok := req.URL.Query()["node-meta"]; ok { filters := make(map[string]string) for _, filter := range filterList { - key, value := ParseMetaPair(filter) + key, value := parseMetaPair(filter) filters[key] = value } return filters @@ -1046,6 +1046,14 @@ func (s *HTTPServer) parseMetaFilter(req *http.Request) map[string]string { return nil } +func parseMetaPair(raw string) (string, string) { + pair := strings.SplitN(raw, ":", 2) + if len(pair) == 2 { + return pair[0], pair[1] + } + return pair[0], "" +} + // parseInternal is a convenience method for endpoints that need // to use both parseWait and parseDC. func (s *HTTPServer) parseInternal(resp http.ResponseWriter, req *http.Request, dc *string, b structs.QueryOptionsCompat) bool { diff --git a/agent/http_decode_test.go b/agent/http_decode_test.go index 2fc79334a..999361443 100644 --- a/agent/http_decode_test.go +++ b/agent/http_decode_test.go @@ -1984,283 +1984,6 @@ func TestDecodeCatalogRegister(t *testing.T) { } } -// discoveryChainReadRequest: -// OverrideMeshGateway structs.MeshGatewayConfig -// Mode structs.MeshGatewayMode // string -// OverrideProtocol string -// OverrideConnectTimeout time.Duration -func TestDecodeDiscoveryChainRead(t *testing.T) { - var weaklyTypedDurationTCs = []translateValueTestCase{ - { - desc: "positive string integer (weakly typed)", - durations: &durationTC{ - in: `"2000"`, - }, - wantErr: true, - }, - { - desc: "negative string integer (weakly typed)", - durations: &durationTC{ - in: `"-50"`, - }, - wantErr: true, - }, - } - - for _, tc := range append(durationTestCases, weaklyTypedDurationTCs...) { - t.Run(tc.desc, func(t *testing.T) { - // set up request body - jsonStr := fmt.Sprintf(`{ - "OverrideConnectTimeout": %s - }`, tc.durations.in) - body := bytes.NewBuffer([]byte(jsonStr)) - - var out discoveryChainReadRequest - err := decodeBody(body, &out) - if err == nil && tc.wantErr { - t.Fatal("expected err, got nil") - } - if err != nil && !tc.wantErr { - t.Fatalf("expected nil error, got %v", err) - } - if out.OverrideConnectTimeout != tc.durations.want { - t.Fatalf("expected OverrideConnectTimeout to be %s, got %s", tc.durations.want, out.OverrideConnectTimeout) - } - }) - } - - // Other possibly weakly-typed inputs.. - var weaklyTypedStringTCs = []struct { - desc string - in, want string - wantErr bool - }{ - { - desc: "positive integer for string field (weakly typed)", - in: `200`, - want: "200", - }, - { - desc: "negative integer for string field (weakly typed)", - in: `-200`, - want: "-200", - }, - { - desc: "bool for string field (weakly typed)", - in: `true`, - want: "true", // previously: "1" - }, - { - desc: "float for string field (weakly typed)", - in: `1.2223`, - want: "1.2223", - }, - { - desc: "map for string field (weakly typed)", - in: `{}`, - wantErr: true, - }, - { - desc: "slice for string field (weakly typed)", - in: `[]`, - wantErr: true, // previously: want: "" - }, - } - - for _, tc := range weaklyTypedStringTCs { - t.Run(tc.desc, func(t *testing.T) { - // set up request body - jsonStr := fmt.Sprintf(`{ - "OverrideProtocol": %[1]s, - "OverrideMeshGateway": {"Mode": %[1]s} - }`, tc.in) - body := bytes.NewBuffer([]byte(jsonStr)) - - var out discoveryChainReadRequest - err := decodeBody(body, &out) - - if err == nil && tc.wantErr { - t.Fatal("expected err, got nil") - } - if err != nil && !tc.wantErr { - t.Fatalf("expected nil error, got %v", err) - } - if out.OverrideProtocol != tc.want { - t.Fatalf("expected OverrideProtocol to be %s, got %s", tc.want, out.OverrideProtocol) - } - if out.OverrideMeshGateway.Mode != structs.MeshGatewayMode(tc.want) { - t.Fatalf("expected OverrideMeshGateway.Mode to be %s, got %s", tc.want, out.OverrideMeshGateway.Mode) - } - }) - } - - // translate field tcs - - overrideMeshGatewayFields := []string{ - `"OverrideMeshGateway": {"Mode": %s}`, - `"override_mesh_gateway": {"Mode": %s}`, - } - - overrideMeshGatewayEqFn := func(out interface{}, want interface{}) error { - got := out.(discoveryChainReadRequest).OverrideMeshGateway.Mode - if got != structs.MeshGatewayMode(want.(string)) { - return fmt.Errorf("expected OverrideMeshGateway to be %s, got %s", want, got) - } - return nil - } - - var translateOverrideMeshGatewayTCs = []translateKeyTestCase{ - { - desc: "OverrideMeshGateway: both set", - in: []interface{}{`"one"`, `"two"`}, - want: "one", - jsonFmtStr: `{` + strings.Join(overrideMeshGatewayFields, ",") + `}`, - equalityFn: overrideMeshGatewayEqFn, - }, - { - desc: "OverrideMeshGateway: first set", - in: []interface{}{`"one"`}, - want: "one", - jsonFmtStr: `{` + overrideMeshGatewayFields[0] + `}`, - equalityFn: overrideMeshGatewayEqFn, - }, - { - desc: "OverrideMeshGateway: second set", - in: []interface{}{`"two"`}, - want: "two", - jsonFmtStr: `{` + overrideMeshGatewayFields[1] + `}`, - equalityFn: overrideMeshGatewayEqFn, - }, - { - desc: "OverrideMeshGateway: neither set", - in: []interface{}{}, - want: "", // zero value - jsonFmtStr: `{}`, - equalityFn: overrideMeshGatewayEqFn, - }, - } - - overrideProtocolFields := []string{ - `"OverrideProtocol": %s`, - `"override_protocol": %s`, - } - - overrideProtocolEqFn := func(out interface{}, want interface{}) error { - got := out.(discoveryChainReadRequest).OverrideProtocol - if got != want { - return fmt.Errorf("expected OverrideProtocol to be %s, got %s", want, got) - } - return nil - } - - var translateOverrideProtocolTCs = []translateKeyTestCase{ - { - desc: "OverrideProtocol: both set", - in: []interface{}{`"one"`, `"two"`}, - want: "one", - jsonFmtStr: `{` + strings.Join(overrideProtocolFields, ",") + `}`, - equalityFn: overrideProtocolEqFn, - }, - { - desc: "OverrideProtocol: first set", - in: []interface{}{`"one"`}, - want: "one", - jsonFmtStr: `{` + overrideProtocolFields[0] + `}`, - equalityFn: overrideProtocolEqFn, - }, - { - desc: "OverrideProtocol: second set", - in: []interface{}{`"two"`}, - want: "two", - jsonFmtStr: `{` + overrideProtocolFields[1] + `}`, - equalityFn: overrideProtocolEqFn, - }, - { - desc: "OverrideProtocol: neither set", - in: []interface{}{}, - want: "", // zero value - jsonFmtStr: `{}`, - equalityFn: overrideProtocolEqFn, - }, - } - - overrideConnectTimeoutFields := []string{ - `"OverrideConnectTimeout": %s`, - `"override_connect_timeout": %s`, - } - - overrideConnectTimeoutEqFn := func(out interface{}, want interface{}) error { - got := out.(discoveryChainReadRequest).OverrideConnectTimeout - if got != want { - return fmt.Errorf("expected OverrideConnectTimeout to be %s, got %s", want, got) - } - return nil - } - - var translateOverrideConnectTimeoutTCs = []translateKeyTestCase{ - { - desc: "OverrideConnectTimeout: both set", - in: []interface{}{`"2h0m"`, `"3h0m"`}, - want: 2 * time.Hour, - jsonFmtStr: "{" + strings.Join(overrideConnectTimeoutFields, ",") + "}", - equalityFn: overrideConnectTimeoutEqFn, - }, - { - desc: "OverrideConnectTimeout: first set", - in: []interface{}{`"2h0m"`}, - want: 2 * time.Hour, - jsonFmtStr: "{" + overrideConnectTimeoutFields[0] + "}", - equalityFn: overrideConnectTimeoutEqFn, - }, - { - desc: "OverrideConnectTimeout: second set", - in: []interface{}{`"3h0m"`}, - want: 3 * time.Hour, - jsonFmtStr: "{" + overrideConnectTimeoutFields[1] + "}", - equalityFn: overrideConnectTimeoutEqFn, - }, - { - desc: "OverrideConnectTimeout: neither set", - in: []interface{}{}, - want: time.Duration(0), - jsonFmtStr: "{}", - equalityFn: overrideConnectTimeoutEqFn, - }, - } - - // lib.TranslateKeys(raw, map[string]string{ - // "override_mesh_gateway": "overridemeshgateway", - // "override_protocol": "overrideprotocol", - // "override_connect_timeout": "overrideconnecttimeout", - // }) - - translateFieldTCs := [][]translateKeyTestCase{ - translateOverrideMeshGatewayTCs, - translateOverrideProtocolTCs, - translateOverrideConnectTimeoutTCs, - } - - for _, tcGroup := range translateFieldTCs { - for _, tc := range tcGroup { - t.Run(tc.desc, func(t *testing.T) { - jsonStr := fmt.Sprintf(tc.jsonFmtStr, tc.in...) - body := bytes.NewBuffer([]byte(jsonStr)) - - var out discoveryChainReadRequest - err := decodeBody(body, &out) - if err != nil { - t.Fatal(err) - } - - if err := tc.equalityFn(out, tc.want); err != nil { - t.Fatal(err) - } - }) - } - } - -} - // IntentionRequest: // Datacenter string // Op structs.IntentionOp