Merge pull request #9249 from hashicorp/dnephin/config-tags

config: use fields to detect enterprise-only settings
This commit is contained in:
Daniel Nephin 2021-01-07 19:49:29 -05:00 committed by GitHub
commit d5bdc2f539
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 111 additions and 168 deletions

View File

@ -7,7 +7,6 @@ import (
) )
func TestBuildAndValidate_HTTPMaxConnsPerClientExceedsRLimit(t *testing.T) { func TestBuildAndValidate_HTTPMaxConnsPerClientExceedsRLimit(t *testing.T) {
t.Parallel()
hcl := ` hcl := `
limits{ limits{
# We put a very high value to be sure to fail # We put a very high value to be sure to fail

View File

@ -116,14 +116,6 @@ func NewBuilder(opts BuilderOpts) (*Builder, error) {
return nil, fmt.Errorf("config: -config-format must be either 'hcl' or 'json'") return nil, fmt.Errorf("config: -config-format must be either 'hcl' or 'json'")
} }
newSource := func(name string, v interface{}) Source {
b, err := json.MarshalIndent(v, "", " ")
if err != nil {
panic(err)
}
return FileSource{Name: name, Format: "json", Data: string(b)}
}
b := &Builder{ b := &Builder{
opts: opts, opts: opts,
Head: []Source{DefaultSource(), DefaultEnterpriseSource()}, Head: []Source{DefaultSource(), DefaultEnterpriseSource()},
@ -138,8 +130,8 @@ func NewBuilder(opts BuilderOpts) (*Builder, error) {
// we need to merge all slice values defined in flags before we // we need to merge all slice values defined in flags before we
// merge the config files since the flag values for slices are // merge the config files since the flag values for slices are
// otherwise appended instead of prepended. // otherwise appended instead of prepended.
slices, values := b.splitSlicesAndValues(opts.Config) slices, values := splitSlicesAndValues(opts.Config)
b.Head = append(b.Head, newSource("flags.slices", slices)) b.Head = append(b.Head, LiteralSource{Name: "flags.slices", Config: slices})
for _, path := range opts.ConfigFiles { for _, path := range opts.ConfigFiles {
sources, err := b.sourcesFromPath(path, opts.ConfigFormat) sources, err := b.sourcesFromPath(path, opts.ConfigFormat)
if err != nil { if err != nil {
@ -147,7 +139,7 @@ func NewBuilder(opts BuilderOpts) (*Builder, error) {
} }
b.Sources = append(b.Sources, sources...) b.Sources = append(b.Sources, sources...)
} }
b.Tail = append(b.Tail, newSource("flags.values", values)) b.Tail = append(b.Tail, LiteralSource{Name: "flags.values", Config: values})
for i, s := range opts.HCL { for i, s := range opts.HCL {
b.Tail = append(b.Tail, FileSource{ b.Tail = append(b.Tail, FileSource{
Name: fmt.Sprintf("flags-%d.hcl", i), Name: fmt.Sprintf("flags-%d.hcl", i),
@ -314,8 +306,9 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
return RuntimeConfig{}, fmt.Errorf("failed to parse %v: %s", s.Source(), unusedErr) return RuntimeConfig{}, fmt.Errorf("failed to parse %v: %s", s.Source(), unusedErr)
} }
// for now this is a soft failure that will cause warnings but not actual problems for _, err := range validateEnterpriseConfigKeys(&c2) {
b.validateEnterpriseConfigKeys(&c2, md.Keys) b.warn("%s", err)
}
// if we have a single 'check' or 'service' we need to add them to the // if we have a single 'check' or 'service' we need to add them to the
// list of checks and services first since we cannot merge them // list of checks and services first since we cannot merge them
@ -1484,7 +1477,7 @@ func addrsUnique(inuse map[string]string, name string, addrs []net.Addr) error {
// splitSlicesAndValues moves all slice values defined in c to 'slices' // splitSlicesAndValues moves all slice values defined in c to 'slices'
// and all other values to 'values'. // and all other values to 'values'.
func (b *Builder) splitSlicesAndValues(c Config) (slices, values Config) { func splitSlicesAndValues(c Config) (slices, values Config) {
v, t := reflect.ValueOf(c), reflect.TypeOf(c) v, t := reflect.ValueOf(c), reflect.TypeOf(c)
rs, rv := reflect.New(t), reflect.New(t) rs, rv := reflect.New(t), reflect.New(t)
@ -1843,10 +1836,20 @@ func (b *Builder) stringValWithDefault(v *string, defaultVal string) string {
return *v return *v
} }
// Deprecated: use the stringVal() function instead of this Builder method. This
// method is being left here temporarily as there are many callers. It will be
// removed in the future.
func (b *Builder) stringVal(v *string) string { func (b *Builder) stringVal(v *string) string {
return b.stringValWithDefault(v, "") return b.stringValWithDefault(v, "")
} }
func stringVal(v *string) string {
if v == nil {
return ""
}
return *v
}
func (b *Builder) float64ValWithDefault(v *float64, defaultVal float64) float64 { func (b *Builder) float64ValWithDefault(v *float64, defaultVal float64) float64 {
if v == nil { if v == nil {
return defaultVal return defaultVal

View File

@ -4,47 +4,55 @@ package config
import ( import (
"fmt" "fmt"
"github.com/hashicorp/go-multierror"
) )
var ( // validateEnterpriseConfig is a function to validate the enterprise specific
enterpriseConfigMap map[string]func(*Config) = map[string]func(c *Config){ // configuration items after Parsing but before merging into the overall
"non_voting_server": func(c *Config) { // configuration. The original intent is to use it to ensure that we warn
// to maintain existing compatibility we don't nullify the value // for enterprise configurations used in OSS.
}, func validateEnterpriseConfigKeys(config *Config) []error {
"read_replica": func(c *Config) { var result []error
// to maintain existing compatibility we don't nullify the value add := func(k string) {
}, result = append(result, enterpriseConfigKeyError{key: k})
"segment": func(c *Config) {
// to maintain existing compatibility we don't nullify the value
},
"segments": func(c *Config) {
// to maintain existing compatibility we don't nullify the value
},
"autopilot.redundancy_zone_tag": func(c *Config) {
// to maintain existing compatibility we don't nullify the value
},
"autopilot.upgrade_version_tag": func(c *Config) {
// to maintain existing compatibility we don't nullify the value
},
"autopilot.disable_upgrade_migration": func(c *Config) {
// to maintain existing compatibility we don't nullify the value
},
"dns_config.prefer_namespace": func(c *Config) {
c.DNS.PreferNamespace = nil
},
"acl.msp_disable_bootstrap": func(c *Config) {
c.ACL.MSPDisableBootstrap = nil
},
"acl.tokens.managed_service_provider": func(c *Config) {
c.ACL.Tokens.ManagedServiceProvider = nil
},
"audit": func(c *Config) {
c.Audit = nil
},
} }
)
if config.ReadReplica != nil {
add(`read_replica (or the deprecated non_voting_server)`)
}
if stringVal(config.SegmentName) != "" {
add("segment")
}
if len(config.Segments) > 0 {
add("segments")
}
if stringVal(config.Autopilot.RedundancyZoneTag) != "" {
add("autopilot.redundancy_zone_tag")
}
if stringVal(config.Autopilot.UpgradeVersionTag) != "" {
add("autopilot.upgrade_version_tag")
}
if config.Autopilot.DisableUpgradeMigration != nil {
add("autopilot.disable_upgrade_migration")
}
if config.DNS.PreferNamespace != nil {
add("dns_config.prefer_namespace")
config.DNS.PreferNamespace = nil
}
if config.ACL.MSPDisableBootstrap != nil {
add("acl.msp_disable_bootstrap")
config.ACL.MSPDisableBootstrap = nil
}
if len(config.ACL.Tokens.ManagedServiceProvider) > 0 {
add("acl.tokens.managed_service_provider")
config.ACL.Tokens.ManagedServiceProvider = nil
}
if config.Audit != nil {
add("audit")
config.Audit = nil
}
return result
}
type enterpriseConfigKeyError struct { type enterpriseConfigKeyError struct {
key string key string
@ -61,23 +69,3 @@ func (*Builder) BuildEnterpriseRuntimeConfig(_ *RuntimeConfig, _ *Config) error
func (*Builder) validateEnterpriseConfig(_ RuntimeConfig) error { func (*Builder) validateEnterpriseConfig(_ RuntimeConfig) error {
return nil return nil
} }
// validateEnterpriseConfig is a function to validate the enterprise specific
// configuration items after Parsing but before merging into the overall
// configuration. The original intent is to use it to ensure that we warn
// for enterprise configurations used in OSS.
func (b *Builder) validateEnterpriseConfigKeys(config *Config, keys []string) error {
var err error
for _, k := range keys {
if unset, ok := enterpriseConfigMap[k]; ok {
keyErr := enterpriseConfigKeyError{key: k}
b.warn(keyErr.Error())
err = multierror.Append(err, keyErr)
unset(config)
}
}
return err
}

View File

@ -5,15 +5,13 @@ package config
import ( import (
"testing" "testing"
"github.com/hashicorp/go-multierror"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestBuilder_validateEnterpriseConfigKeys(t *testing.T) { func TestValidateEnterpriseConfigKeys(t *testing.T) {
// ensure that all the enterprise configurations // ensure that all the enterprise configurations
type testCase struct { type testCase struct {
config Config config Config
keys []string
badKeys []string badKeys []string
check func(t *testing.T, c *Config) check func(t *testing.T, c *Config)
} }
@ -22,34 +20,22 @@ func TestBuilder_validateEnterpriseConfigKeys(t *testing.T) {
stringVal := "string" stringVal := "string"
cases := map[string]testCase{ cases := map[string]testCase{
"non_voting_server": {
config: Config{
ReadReplica: &boolVal,
},
keys: []string{"non_voting_server"},
badKeys: []string{"non_voting_server"},
},
"read_replica": { "read_replica": {
config: Config{ config: Config{
ReadReplica: &boolVal, ReadReplica: &boolVal,
}, },
keys: []string{"read_replica"}, badKeys: []string{"read_replica (or the deprecated non_voting_server)"},
badKeys: []string{"read_replica"},
}, },
"segment": { "segment": {
config: Config{ config: Config{
SegmentName: &stringVal, SegmentName: &stringVal,
}, },
keys: []string{"segment"},
badKeys: []string{"segment"}, badKeys: []string{"segment"},
}, },
"segments": { "segments": {
config: Config{ config: Config{
Segments: []Segment{ Segments: []Segment{{Name: &stringVal}},
{Name: &stringVal},
}, },
},
keys: []string{"segments"},
badKeys: []string{"segments"}, badKeys: []string{"segments"},
}, },
"autopilot.redundancy_zone_tag": { "autopilot.redundancy_zone_tag": {
@ -58,7 +44,6 @@ func TestBuilder_validateEnterpriseConfigKeys(t *testing.T) {
RedundancyZoneTag: &stringVal, RedundancyZoneTag: &stringVal,
}, },
}, },
keys: []string{"autopilot.redundancy_zone_tag"},
badKeys: []string{"autopilot.redundancy_zone_tag"}, badKeys: []string{"autopilot.redundancy_zone_tag"},
}, },
"autopilot.upgrade_version_tag": { "autopilot.upgrade_version_tag": {
@ -67,25 +52,18 @@ func TestBuilder_validateEnterpriseConfigKeys(t *testing.T) {
UpgradeVersionTag: &stringVal, UpgradeVersionTag: &stringVal,
}, },
}, },
keys: []string{"autopilot.upgrade_version_tag"},
badKeys: []string{"autopilot.upgrade_version_tag"}, badKeys: []string{"autopilot.upgrade_version_tag"},
}, },
"autopilot.disable_upgrade_migration": { "autopilot.disable_upgrade_migration": {
config: Config{ config: Config{
Autopilot: Autopilot{ Autopilot: Autopilot{DisableUpgradeMigration: &boolVal},
DisableUpgradeMigration: &boolVal,
}, },
},
keys: []string{"autopilot.disable_upgrade_migration"},
badKeys: []string{"autopilot.disable_upgrade_migration"}, badKeys: []string{"autopilot.disable_upgrade_migration"},
}, },
"dns_config.prefer_namespace": { "dns_config.prefer_namespace": {
config: Config{ config: Config{
DNS: DNS{ DNS: DNS{PreferNamespace: &boolVal},
PreferNamespace: &boolVal,
}, },
},
keys: []string{"dns_config.prefer_namespace"},
badKeys: []string{"dns_config.prefer_namespace"}, badKeys: []string{"dns_config.prefer_namespace"},
check: func(t *testing.T, c *Config) { check: func(t *testing.T, c *Config) {
require.Nil(t, c.DNS.PreferNamespace) require.Nil(t, c.DNS.PreferNamespace)
@ -93,11 +71,8 @@ func TestBuilder_validateEnterpriseConfigKeys(t *testing.T) {
}, },
"acl.msp_disable_bootstrap": { "acl.msp_disable_bootstrap": {
config: Config{ config: Config{
ACL: ACL{ ACL: ACL{MSPDisableBootstrap: &boolVal},
MSPDisableBootstrap: &boolVal,
}, },
},
keys: []string{"acl.msp_disable_bootstrap"},
badKeys: []string{"acl.msp_disable_bootstrap"}, badKeys: []string{"acl.msp_disable_bootstrap"},
check: func(t *testing.T, c *Config) { check: func(t *testing.T, c *Config) {
require.Nil(t, c.ACL.MSPDisableBootstrap) require.Nil(t, c.ACL.MSPDisableBootstrap)
@ -116,7 +91,6 @@ func TestBuilder_validateEnterpriseConfigKeys(t *testing.T) {
}, },
}, },
}, },
keys: []string{"acl.tokens.managed_service_provider"},
badKeys: []string{"acl.tokens.managed_service_provider"}, badKeys: []string{"acl.tokens.managed_service_provider"},
check: func(t *testing.T, c *Config) { check: func(t *testing.T, c *Config) {
require.Empty(t, c.ACL.Tokens.ManagedServiceProvider) require.Empty(t, c.ACL.Tokens.ManagedServiceProvider)
@ -127,40 +101,29 @@ func TestBuilder_validateEnterpriseConfigKeys(t *testing.T) {
config: Config{ config: Config{
ReadReplica: &boolVal, ReadReplica: &boolVal,
SegmentName: &stringVal, SegmentName: &stringVal,
ACL: ACL{Tokens: Tokens{AgentMaster: &stringVal}},
}, },
keys: []string{"non_voting_server", "read_replica", "segment", "acl.tokens.agent_master"}, badKeys: []string{"read_replica (or the deprecated non_voting_server)", "segment"},
badKeys: []string{"non_voting_server", "read_replica", "segment"},
}, },
} }
for name, tcase := range cases { for name, tcase := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
b := &Builder{} errs := validateEnterpriseConfigKeys(&tcase.config)
if len(tcase.badKeys) == 0 {
err := b.validateEnterpriseConfigKeys(&tcase.config, tcase.keys) require.Len(t, errs, 0)
if len(tcase.badKeys) > 0 { return
require.Error(t, err)
multiErr, ok := err.(*multierror.Error)
require.True(t, ok)
var badKeys []string
for _, e := range multiErr.Errors {
if keyErr, ok := e.(enterpriseConfigKeyError); ok {
badKeys = append(badKeys, keyErr.key)
require.Contains(t, b.Warnings, keyErr.Error())
}
} }
require.ElementsMatch(t, tcase.badKeys, badKeys) var expected []error
for _, k := range tcase.badKeys {
expected = append(expected, enterpriseConfigKeyError{key: k})
}
require.ElementsMatch(t, expected, errs)
if tcase.check != nil { if tcase.check != nil {
tcase.check(t, &tcase.config) tcase.check(t, &tcase.config)
} }
} else {
require.NoError(t, err)
}
}) })
} }
} }

View File

@ -12,16 +12,17 @@ var entTokenConfigSanitize = `"EnterpriseConfig": {},`
func entFullRuntimeConfig(rt *RuntimeConfig) {} func entFullRuntimeConfig(rt *RuntimeConfig) {}
var enterpriseReadReplicaWarnings []string = []string{enterpriseConfigKeyError{key: "read_replica"}.Error()} var enterpriseReadReplicaWarnings = []string{enterpriseConfigKeyError{key: "read_replica (or the deprecated non_voting_server)"}.Error()}
var enterpriseConfigKeyWarnings []string var enterpriseConfigKeyWarnings = []string{
enterpriseConfigKeyError{key: "read_replica (or the deprecated non_voting_server)"}.Error(),
func init() { enterpriseConfigKeyError{key: "segment"}.Error(),
for k := range enterpriseConfigMap { enterpriseConfigKeyError{key: "segments"}.Error(),
if k == "non_voting_server" { enterpriseConfigKeyError{key: "autopilot.redundancy_zone_tag"}.Error(),
// this is an alias for "read_replica" so we shouldn't see it in warnings enterpriseConfigKeyError{key: "autopilot.upgrade_version_tag"}.Error(),
continue enterpriseConfigKeyError{key: "autopilot.disable_upgrade_migration"}.Error(),
} enterpriseConfigKeyError{key: "dns_config.prefer_namespace"}.Error(),
enterpriseConfigKeyWarnings = append(enterpriseConfigKeyWarnings, enterpriseConfigKeyError{key: k}.Error()) enterpriseConfigKeyError{key: "acl.msp_disable_bootstrap"}.Error(),
} enterpriseConfigKeyError{key: "acl.tokens.managed_service_provider"}.Error(),
enterpriseConfigKeyError{key: "audit"}.Error(),
} }

View File

@ -19,6 +19,8 @@ import (
"time" "time"
"github.com/armon/go-metrics/prometheus" "github.com/armon/go-metrics/prometheus"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
@ -35,10 +37,9 @@ import (
type configTest struct { type configTest struct {
desc string desc string
args []string args []string
pre, post func() pre func()
json, jsontail []string json, jsontail []string
hcl, hcltail []string hcl, hcltail []string
skipformat bool
privatev4 func() ([]*net.IPAddr, error) privatev4 func() ([]*net.IPAddr, error)
publicv6 func() ([]*net.IPAddr, error) publicv6 func() ([]*net.IPAddr, error)
patch func(rt *RuntimeConfig) patch func(rt *RuntimeConfig)
@ -4821,7 +4822,6 @@ func TestBuilder_BuildAndValidate_ConfigFlagsAndEdgecases(t *testing.T) {
} }
func testConfig(t *testing.T, tests []configTest, dataDir string) { func testConfig(t *testing.T, tests []configTest, dataDir string) {
t.Helper()
for _, tt := range tests { for _, tt := range tests {
for pass, format := range []string{"json", "hcl"} { for pass, format := range []string{"json", "hcl"} {
// clean data dir before every test // clean data dir before every test
@ -4837,22 +4837,15 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
// json and hcl sources need to be in sync // json and hcl sources need to be in sync
// to make sure we're generating the same config // to make sure we're generating the same config
if len(tt.json) != len(tt.hcl) && !tt.skipformat { if len(tt.json) != len(tt.hcl) {
t.Fatal(tt.desc, ": JSON and HCL test case out of sync") t.Fatal(tt.desc, ": JSON and HCL test case out of sync")
} }
// select the source
srcs, tails := tt.json, tt.jsontail srcs, tails := tt.json, tt.jsontail
if format == "hcl" { if format == "hcl" {
srcs, tails = tt.hcl, tt.hcltail srcs, tails = tt.hcl, tt.hcltail
} }
// If we're skipping a format and the current format is empty,
// then skip it!
if tt.skipformat && len(srcs) == 0 {
continue
}
// build the description // build the description
var desc []string var desc []string
if !flagsOnly { if !flagsOnly {
@ -4863,8 +4856,8 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
} }
t.Run(strings.Join(desc, ":"), func(t *testing.T) { t.Run(strings.Join(desc, ":"), func(t *testing.T) {
// first parse the flags
flags := BuilderOpts{} flags := BuilderOpts{}
fs := flag.NewFlagSet("", flag.ContinueOnError) fs := flag.NewFlagSet("", flag.ContinueOnError)
AddFlags(fs, &flags) AddFlags(fs, &flags)
err := fs.Parse(tt.args) err := fs.Parse(tt.args)
@ -4876,17 +4869,10 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
if tt.pre != nil { if tt.pre != nil {
tt.pre() tt.pre()
} }
defer func() {
if tt.post != nil {
tt.post()
}
}()
// Then create a builder with the flags. // Then create a builder with the flags.
b, err := NewBuilder(flags) b, err := NewBuilder(flags)
if err != nil { require.NoError(t, err)
t.Fatal("NewBuilder", err)
}
patchBuilderShims(b) patchBuilderShims(b)
if tt.hostname != nil { if tt.hostname != nil {
@ -4899,7 +4885,7 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
b.opts.getPublicIPv6 = tt.publicv6 b.opts.getPublicIPv6 = tt.publicv6
} }
// read the source fragements // read the source fragments
for i, data := range srcs { for i, data := range srcs {
b.Sources = append(b.Sources, FileSource{ b.Sources = append(b.Sources, FileSource{
Name: fmt.Sprintf("src-%d.%s", i, format), Name: fmt.Sprintf("src-%d.%s", i, format),
@ -4915,7 +4901,6 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
}) })
} }
// build/merge the config fragments
actual, err := b.BuildAndValidate() actual, err := b.BuildAndValidate()
if err == nil && tt.err != "" { if err == nil && tt.err != "" {
t.Fatalf("got no error want %q", tt.err) t.Fatalf("got no error want %q", tt.err)
@ -4943,9 +4928,7 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
} }
patchBuilderShims(x) patchBuilderShims(x)
expected, err := x.Build() expected, err := x.Build()
if err != nil { require.NoError(t, err)
t.Fatalf("build default failed: %s", err)
}
if tt.patch != nil { if tt.patch != nil {
tt.patch(&expected) tt.patch(&expected)
} }
@ -4959,12 +4942,19 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
if tt.patchActual != nil { if tt.patchActual != nil {
tt.patchActual(&actual) tt.patchActual(&actual)
} }
require.Equal(t, expected, actual) assertDeepEqual(t, expected, actual, cmpopts.EquateEmpty())
}) })
} }
} }
} }
func assertDeepEqual(t *testing.T, x, y interface{}, opts ...cmp.Option) {
t.Helper()
if diff := cmp.Diff(x, y, opts...); diff != "" {
t.Fatalf("assertion failed: values are not equal\n--- expected\n+++ actual\n%v", diff)
}
}
func TestNewBuilder_InvalidConfigFormat(t *testing.T) { func TestNewBuilder_InvalidConfigFormat(t *testing.T) {
_, err := NewBuilder(BuilderOpts{ConfigFormat: "yaml"}) _, err := NewBuilder(BuilderOpts{ConfigFormat: "yaml"})
require.Error(t, err) require.Error(t, err)
@ -7396,7 +7386,6 @@ func TestNonZero(t *testing.T) {
} }
func TestConfigDecodeBytes(t *testing.T) { func TestConfigDecodeBytes(t *testing.T) {
t.Parallel()
// Test with some input // Test with some input
src := []byte("abc") src := []byte("abc")
key := base64.StdEncoding.EncodeToString(src) key := base64.StdEncoding.EncodeToString(src)