diff --git a/agent/agent.go b/agent/agent.go index 7a313cb4f..91d42cb73 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -361,9 +361,9 @@ type Agent struct { // run by the Agent routineManager *routine.Manager - // FileWatcher is the watcher responsible to report events when a config file + // configFileWatcher is the watcher responsible to report events when a config file // changed - FileWatcher config.Watcher + configFileWatcher config.Watcher // xdsServer serves the XDS protocol for configuring Envoy proxies. xdsServer *xds.Server @@ -462,6 +462,13 @@ func New(bd BaseDeps) (*Agent, error) { a.baseDeps.WatchedFiles = append(a.baseDeps.WatchedFiles, f.Cfg.CertFile) } } + if a.baseDeps.RuntimeConfig.AutoReloadConfig && len(a.baseDeps.WatchedFiles) > 0 { + w, err := config.NewRateLimitedFileWatcher(a.baseDeps.WatchedFiles, a.baseDeps.Logger, a.baseDeps.RuntimeConfig.AutoReloadConfigCoalesceInterval) + if err != nil { + return nil, err + } + a.configFileWatcher = w + } return &a, nil } @@ -713,25 +720,20 @@ func (a *Agent) Start(ctx context.Context) error { }) // start a go routine to reload config based on file watcher events - if a.baseDeps.RuntimeConfig.AutoReloadConfig && len(a.baseDeps.WatchedFiles) > 0 { - w, err := config.NewFileWatcher(a.baseDeps.WatchedFiles, a.baseDeps.Logger) - if err != nil { - a.baseDeps.Logger.Error("error loading config", "error", err) - } else { - a.FileWatcher = w - a.baseDeps.Logger.Debug("starting file watcher") - a.FileWatcher.Start(context.Background()) - go func() { - for event := range a.FileWatcher.EventsCh() { - a.baseDeps.Logger.Debug("auto-reload config triggered", "event-file", event.Filename) - err := a.AutoReloadConfig() - if err != nil { - a.baseDeps.Logger.Error("error loading config", "error", err) - } + if a.configFileWatcher != nil { + a.baseDeps.Logger.Debug("starting file watcher") + a.configFileWatcher.Start(context.Background()) + go func() { + for event := range a.configFileWatcher.EventsCh() { + a.baseDeps.Logger.Debug("auto-reload config triggered", "num-events", len(event.Filenames)) + err := a.AutoReloadConfig() + if err != nil { + a.baseDeps.Logger.Error("error loading config", "error", err) } - }() - } + } + }() } + return nil } @@ -1413,8 +1415,8 @@ func (a *Agent) ShutdownAgent() error { a.stopAllWatches() // Stop config file watcher - if a.FileWatcher != nil { - a.FileWatcher.Stop() + if a.configFileWatcher != nil { + a.configFileWatcher.Stop() } a.stopLicenseManager() @@ -3772,13 +3774,13 @@ func (a *Agent) reloadConfig(autoReload bool) error { {a.config.TLS.HTTPS, newCfg.TLS.HTTPS}, } { if f.oldCfg.KeyFile != f.newCfg.KeyFile { - a.FileWatcher.Replace(f.oldCfg.KeyFile, f.newCfg.KeyFile) + a.configFileWatcher.Replace(f.oldCfg.KeyFile, f.newCfg.KeyFile) if err != nil { return err } } if f.oldCfg.CertFile != f.newCfg.CertFile { - a.FileWatcher.Replace(f.oldCfg.CertFile, f.newCfg.CertFile) + a.configFileWatcher.Replace(f.oldCfg.CertFile, f.newCfg.CertFile) if err != nil { return err } diff --git a/agent/agent_test.go b/agent/agent_test.go index 43d9bd31d..ba82f127f 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -5545,7 +5545,8 @@ func TestAgent_AutoReloadDoReload_WhenCertThenKeyUpdated(t *testing.T) { testrpc.WaitForTestAgent(t, srv.RPC, "dc1", testrpc.WithToken(TestDefaultInitialManagementToken)) - cert1 := srv.tlsConfigurator.Cert() + cert1Pub := srv.tlsConfigurator.Cert().Certificate + cert1Key := srv.tlsConfigurator.Cert().PrivateKey certNew, privateKeyNew, err := tlsutil.GenerateCert(tlsutil.CertOpts{ Signer: signer, @@ -5575,8 +5576,10 @@ func TestAgent_AutoReloadDoReload_WhenCertThenKeyUpdated(t *testing.T) { // cert should not change as we did not update the associated key time.Sleep(1 * time.Second) retry.Run(t, func(r *retry.R) { - require.Equal(r, cert1.Certificate, srv.tlsConfigurator.Cert().Certificate) - require.Equal(r, cert1.PrivateKey, srv.tlsConfigurator.Cert().PrivateKey) + cert := srv.tlsConfigurator.Cert() + require.NotNil(r, cert) + require.Equal(r, cert1Pub, cert.Certificate) + require.Equal(r, cert1Key, cert.PrivateKey) }) require.NoError(t, ioutil.WriteFile(keyFile, []byte(privateKeyNew), 0600)) @@ -5584,8 +5587,8 @@ func TestAgent_AutoReloadDoReload_WhenCertThenKeyUpdated(t *testing.T) { // cert should change as we did not update the associated key time.Sleep(1 * time.Second) retry.Run(t, func(r *retry.R) { - require.NotEqual(r, cert1.Certificate, srv.tlsConfigurator.Cert().Certificate) - require.NotEqual(r, cert1.PrivateKey, srv.tlsConfigurator.Cert().PrivateKey) + require.NotEqual(r, cert1Pub, srv.tlsConfigurator.Cert().Certificate) + require.NotEqual(r, cert1Key, srv.tlsConfigurator.Cert().PrivateKey) }) } @@ -5647,11 +5650,13 @@ func TestAgent_AutoReloadDoReload_WhenKeyThenCertUpdated(t *testing.T) { `), 0600)) srv := StartTestAgent(t, TestAgent{Name: "TestAgent-Server", HCL: hclConfig, configFiles: []string{configFile}}) + defer srv.Shutdown() testrpc.WaitForTestAgent(t, srv.RPC, "dc1", testrpc.WithToken(TestDefaultInitialManagementToken)) - cert1 := srv.tlsConfigurator.Cert() + cert1Pub := srv.tlsConfigurator.Cert().Certificate + cert1Key := srv.tlsConfigurator.Cert().PrivateKey certNew, privateKeyNew, err := tlsutil.GenerateCert(tlsutil.CertOpts{ Signer: signer, @@ -5667,8 +5672,10 @@ func TestAgent_AutoReloadDoReload_WhenKeyThenCertUpdated(t *testing.T) { // cert should not change as we did not update the associated key time.Sleep(1 * time.Second) retry.Run(t, func(r *retry.R) { - require.Equal(r, cert1.Certificate, srv.tlsConfigurator.Cert().Certificate) - require.Equal(r, cert1.PrivateKey, srv.tlsConfigurator.Cert().PrivateKey) + cert := srv.tlsConfigurator.Cert() + require.NotNil(r, cert) + require.Equal(r, cert1Pub, cert.Certificate) + require.Equal(r, cert1Key, cert.PrivateKey) }) require.NoError(t, ioutil.WriteFile(certFileNew, []byte(certNew), 0600)) @@ -5689,10 +5696,13 @@ func TestAgent_AutoReloadDoReload_WhenKeyThenCertUpdated(t *testing.T) { // cert should change as we did not update the associated key time.Sleep(1 * time.Second) retry.Run(t, func(r *retry.R) { - require.NotEqual(r, cert1.Certificate, srv.tlsConfigurator.Cert().Certificate) - require.NotEqual(r, cert1.PrivateKey, srv.tlsConfigurator.Cert().PrivateKey) + cert := srv.tlsConfigurator.Cert() + require.NotNil(r, cert) + require.NotEqual(r, cert1Key, cert.Certificate) + require.NotEqual(r, cert1Key, cert.PrivateKey) }) - cert2 := srv.tlsConfigurator.Cert() + cert2Pub := srv.tlsConfigurator.Cert().Certificate + cert2Key := srv.tlsConfigurator.Cert().PrivateKey certNew2, privateKeyNew2, err := tlsutil.GenerateCert(tlsutil.CertOpts{ Signer: signer, @@ -5707,8 +5717,10 @@ func TestAgent_AutoReloadDoReload_WhenKeyThenCertUpdated(t *testing.T) { // cert should not change as we did not update the associated cert time.Sleep(1 * time.Second) retry.Run(t, func(r *retry.R) { - require.Equal(r, cert2.Certificate, srv.tlsConfigurator.Cert().Certificate) - require.Equal(r, cert2.PrivateKey, srv.tlsConfigurator.Cert().PrivateKey) + cert := srv.tlsConfigurator.Cert() + require.NotNil(r, cert) + require.Equal(r, cert2Pub, cert.Certificate) + require.Equal(r, cert2Key, cert.PrivateKey) }) require.NoError(t, ioutil.WriteFile(certFileNew, []byte(certNew2), 0600)) @@ -5716,7 +5728,120 @@ func TestAgent_AutoReloadDoReload_WhenKeyThenCertUpdated(t *testing.T) { // cert should change as we did update the associated key time.Sleep(1 * time.Second) retry.Run(t, func(r *retry.R) { - require.NotEqual(r, cert2.Certificate, srv.tlsConfigurator.Cert().Certificate) - require.NotEqual(r, cert2.PrivateKey, srv.tlsConfigurator.Cert().PrivateKey) + cert := srv.tlsConfigurator.Cert() + require.NotNil(r, cert) + require.NotEqual(r, cert2Pub, cert.Certificate) + require.NotEqual(r, cert2Key, cert.PrivateKey) }) } + +func Test_coalesceTimerTwoPeriods(t *testing.T) { + + certsDir := testutil.TempDir(t, "auto-config") + + // write some test TLS certificates out to the cfg dir + serverName := "server.dc1.consul" + signer, _, err := tlsutil.GeneratePrivateKey() + require.NoError(t, err) + + ca, _, err := tlsutil.GenerateCA(tlsutil.CAOpts{Signer: signer}) + require.NoError(t, err) + + cert, privateKey, err := tlsutil.GenerateCert(tlsutil.CertOpts{ + Signer: signer, + CA: ca, + Name: "Test Cert Name", + Days: 365, + DNSNames: []string{serverName}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + }) + require.NoError(t, err) + + certFile := filepath.Join(certsDir, "cert.pem") + caFile := filepath.Join(certsDir, "cacert.pem") + keyFile := filepath.Join(certsDir, "key.pem") + + require.NoError(t, ioutil.WriteFile(certFile, []byte(cert), 0600)) + require.NoError(t, ioutil.WriteFile(caFile, []byte(ca), 0600)) + require.NoError(t, ioutil.WriteFile(keyFile, []byte(privateKey), 0600)) + + // generate a gossip key + gossipKey := make([]byte, 32) + n, err := rand.Read(gossipKey) + require.NoError(t, err) + require.Equal(t, 32, n) + gossipKeyEncoded := base64.StdEncoding.EncodeToString(gossipKey) + + hclConfig := TestACLConfigWithParams(nil) + + configFile := testutil.TempDir(t, "config") + "/config.hcl" + require.NoError(t, ioutil.WriteFile(configFile, []byte(` + encrypt = "`+gossipKeyEncoded+`" + encrypt_verify_incoming = true + encrypt_verify_outgoing = true + verify_incoming = true + verify_outgoing = true + verify_server_hostname = true + ca_file = "`+caFile+`" + cert_file = "`+certFile+`" + key_file = "`+keyFile+`" + connect { enabled = true } + auto_reload_config = true + `), 0600)) + + coalesceInterval := 100 * time.Millisecond + testAgent := TestAgent{Name: "TestAgent-Server", HCL: hclConfig, configFiles: []string{configFile}, Config: &config.RuntimeConfig{ + AutoReloadConfigCoalesceInterval: coalesceInterval, + }} + srv := StartTestAgent(t, testAgent) + defer srv.Shutdown() + + testrpc.WaitForTestAgent(t, srv.RPC, "dc1", testrpc.WithToken(TestDefaultInitialManagementToken)) + + cert1Pub := srv.tlsConfigurator.Cert().Certificate + cert1Key := srv.tlsConfigurator.Cert().PrivateKey + + certNew, privateKeyNew, err := tlsutil.GenerateCert(tlsutil.CertOpts{ + Signer: signer, + CA: ca, + Name: "Test Cert Name", + Days: 365, + DNSNames: []string{serverName}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + }) + require.NoError(t, err) + certFileNew := filepath.Join(certsDir, "cert_new.pem") + require.NoError(t, ioutil.WriteFile(certFileNew, []byte(certNew), 0600)) + require.NoError(t, ioutil.WriteFile(configFile, []byte(` + encrypt = "`+gossipKeyEncoded+`" + encrypt_verify_incoming = true + encrypt_verify_outgoing = true + verify_incoming = true + verify_outgoing = true + verify_server_hostname = true + ca_file = "`+caFile+`" + cert_file = "`+certFileNew+`" + key_file = "`+keyFile+`" + connect { enabled = true } + auto_reload_config = true + `), 0600)) + + // cert should not change as we did not update the associated key + time.Sleep(coalesceInterval * 2) + retry.Run(t, func(r *retry.R) { + cert := srv.tlsConfigurator.Cert() + require.NotNil(r, cert) + require.Equal(r, cert1Pub, cert.Certificate) + require.Equal(r, cert1Key, cert.PrivateKey) + }) + + require.NoError(t, ioutil.WriteFile(keyFile, []byte(privateKeyNew), 0600)) + + // cert should change as we did not update the associated key + time.Sleep(coalesceInterval * 2) + retry.Run(t, func(r *retry.R) { + require.NotEqual(r, cert1Pub, srv.tlsConfigurator.Cert().Certificate) + require.NotEqual(r, cert1Key, srv.tlsConfigurator.Cert().PrivateKey) + }) + +} diff --git a/agent/config/builder.go b/agent/config/builder.go index b3bdc46f3..d8a4ac042 100644 --- a/agent/config/builder.go +++ b/agent/config/builder.go @@ -1004,64 +1004,65 @@ func (b *builder) build() (rt RuntimeConfig, err error) { LogRotateBytes: intVal(c.LogRotateBytes), LogRotateMaxFiles: intVal(c.LogRotateMaxFiles), }, - MaxQueryTime: b.durationVal("max_query_time", c.MaxQueryTime), - NodeID: types.NodeID(stringVal(c.NodeID)), - NodeMeta: c.NodeMeta, - NodeName: b.nodeName(c.NodeName), - ReadReplica: boolVal(c.ReadReplica), - PidFile: stringVal(c.PidFile), - PrimaryDatacenter: primaryDatacenter, - PrimaryGateways: b.expandAllOptionalAddrs("primary_gateways", c.PrimaryGateways), - PrimaryGatewaysInterval: b.durationVal("primary_gateways_interval", c.PrimaryGatewaysInterval), - RPCAdvertiseAddr: rpcAdvertiseAddr, - RPCBindAddr: rpcBindAddr, - RPCHandshakeTimeout: b.durationVal("limits.rpc_handshake_timeout", c.Limits.RPCHandshakeTimeout), - RPCHoldTimeout: b.durationVal("performance.rpc_hold_timeout", c.Performance.RPCHoldTimeout), - RPCMaxBurst: intVal(c.Limits.RPCMaxBurst), - RPCMaxConnsPerClient: intVal(c.Limits.RPCMaxConnsPerClient), - RPCProtocol: intVal(c.RPCProtocol), - RPCRateLimit: rate.Limit(float64Val(c.Limits.RPCRate)), - RPCConfig: consul.RPCConfig{EnableStreaming: boolValWithDefault(c.RPC.EnableStreaming, serverMode)}, - RaftProtocol: intVal(c.RaftProtocol), - RaftSnapshotThreshold: intVal(c.RaftSnapshotThreshold), - RaftSnapshotInterval: b.durationVal("raft_snapshot_interval", c.RaftSnapshotInterval), - RaftTrailingLogs: intVal(c.RaftTrailingLogs), - ReconnectTimeoutLAN: b.durationVal("reconnect_timeout", c.ReconnectTimeoutLAN), - ReconnectTimeoutWAN: b.durationVal("reconnect_timeout_wan", c.ReconnectTimeoutWAN), - RejoinAfterLeave: boolVal(c.RejoinAfterLeave), - RetryJoinIntervalLAN: b.durationVal("retry_interval", c.RetryJoinIntervalLAN), - RetryJoinIntervalWAN: b.durationVal("retry_interval_wan", c.RetryJoinIntervalWAN), - RetryJoinLAN: b.expandAllOptionalAddrs("retry_join", c.RetryJoinLAN), - RetryJoinMaxAttemptsLAN: intVal(c.RetryJoinMaxAttemptsLAN), - RetryJoinMaxAttemptsWAN: intVal(c.RetryJoinMaxAttemptsWAN), - RetryJoinWAN: b.expandAllOptionalAddrs("retry_join_wan", c.RetryJoinWAN), - SegmentName: stringVal(c.SegmentName), - Segments: segments, - SegmentLimit: intVal(c.SegmentLimit), - SerfAdvertiseAddrLAN: serfAdvertiseAddrLAN, - SerfAdvertiseAddrWAN: serfAdvertiseAddrWAN, - SerfAllowedCIDRsLAN: serfAllowedCIDRSLAN, - SerfAllowedCIDRsWAN: serfAllowedCIDRSWAN, - SerfBindAddrLAN: serfBindAddrLAN, - SerfBindAddrWAN: serfBindAddrWAN, - SerfPortLAN: serfPortLAN, - SerfPortWAN: serfPortWAN, - ServerMode: serverMode, - ServerName: stringVal(c.ServerName), - ServerPort: serverPort, - Services: services, - SessionTTLMin: b.durationVal("session_ttl_min", c.SessionTTLMin), - SkipLeaveOnInt: skipLeaveOnInt, - StartJoinAddrsLAN: b.expandAllOptionalAddrs("start_join", c.StartJoinAddrsLAN), - StartJoinAddrsWAN: b.expandAllOptionalAddrs("start_join_wan", c.StartJoinAddrsWAN), - TaggedAddresses: c.TaggedAddresses, - TranslateWANAddrs: boolVal(c.TranslateWANAddrs), - TxnMaxReqLen: uint64Val(c.Limits.TxnMaxReqLen), - UIConfig: b.uiConfigVal(c.UIConfig), - UnixSocketGroup: stringVal(c.UnixSocket.Group), - UnixSocketMode: stringVal(c.UnixSocket.Mode), - UnixSocketUser: stringVal(c.UnixSocket.User), - Watches: c.Watches, + MaxQueryTime: b.durationVal("max_query_time", c.MaxQueryTime), + NodeID: types.NodeID(stringVal(c.NodeID)), + NodeMeta: c.NodeMeta, + NodeName: b.nodeName(c.NodeName), + ReadReplica: boolVal(c.ReadReplica), + PidFile: stringVal(c.PidFile), + PrimaryDatacenter: primaryDatacenter, + PrimaryGateways: b.expandAllOptionalAddrs("primary_gateways", c.PrimaryGateways), + PrimaryGatewaysInterval: b.durationVal("primary_gateways_interval", c.PrimaryGatewaysInterval), + RPCAdvertiseAddr: rpcAdvertiseAddr, + RPCBindAddr: rpcBindAddr, + RPCHandshakeTimeout: b.durationVal("limits.rpc_handshake_timeout", c.Limits.RPCHandshakeTimeout), + RPCHoldTimeout: b.durationVal("performance.rpc_hold_timeout", c.Performance.RPCHoldTimeout), + RPCMaxBurst: intVal(c.Limits.RPCMaxBurst), + RPCMaxConnsPerClient: intVal(c.Limits.RPCMaxConnsPerClient), + RPCProtocol: intVal(c.RPCProtocol), + RPCRateLimit: rate.Limit(float64Val(c.Limits.RPCRate)), + RPCConfig: consul.RPCConfig{EnableStreaming: boolValWithDefault(c.RPC.EnableStreaming, serverMode)}, + RaftProtocol: intVal(c.RaftProtocol), + RaftSnapshotThreshold: intVal(c.RaftSnapshotThreshold), + RaftSnapshotInterval: b.durationVal("raft_snapshot_interval", c.RaftSnapshotInterval), + RaftTrailingLogs: intVal(c.RaftTrailingLogs), + ReconnectTimeoutLAN: b.durationVal("reconnect_timeout", c.ReconnectTimeoutLAN), + ReconnectTimeoutWAN: b.durationVal("reconnect_timeout_wan", c.ReconnectTimeoutWAN), + RejoinAfterLeave: boolVal(c.RejoinAfterLeave), + RetryJoinIntervalLAN: b.durationVal("retry_interval", c.RetryJoinIntervalLAN), + RetryJoinIntervalWAN: b.durationVal("retry_interval_wan", c.RetryJoinIntervalWAN), + RetryJoinLAN: b.expandAllOptionalAddrs("retry_join", c.RetryJoinLAN), + RetryJoinMaxAttemptsLAN: intVal(c.RetryJoinMaxAttemptsLAN), + RetryJoinMaxAttemptsWAN: intVal(c.RetryJoinMaxAttemptsWAN), + RetryJoinWAN: b.expandAllOptionalAddrs("retry_join_wan", c.RetryJoinWAN), + SegmentName: stringVal(c.SegmentName), + Segments: segments, + SegmentLimit: intVal(c.SegmentLimit), + SerfAdvertiseAddrLAN: serfAdvertiseAddrLAN, + SerfAdvertiseAddrWAN: serfAdvertiseAddrWAN, + SerfAllowedCIDRsLAN: serfAllowedCIDRSLAN, + SerfAllowedCIDRsWAN: serfAllowedCIDRSWAN, + SerfBindAddrLAN: serfBindAddrLAN, + SerfBindAddrWAN: serfBindAddrWAN, + SerfPortLAN: serfPortLAN, + SerfPortWAN: serfPortWAN, + ServerMode: serverMode, + ServerName: stringVal(c.ServerName), + ServerPort: serverPort, + Services: services, + SessionTTLMin: b.durationVal("session_ttl_min", c.SessionTTLMin), + SkipLeaveOnInt: skipLeaveOnInt, + StartJoinAddrsLAN: b.expandAllOptionalAddrs("start_join", c.StartJoinAddrsLAN), + StartJoinAddrsWAN: b.expandAllOptionalAddrs("start_join_wan", c.StartJoinAddrsWAN), + TaggedAddresses: c.TaggedAddresses, + TranslateWANAddrs: boolVal(c.TranslateWANAddrs), + TxnMaxReqLen: uint64Val(c.Limits.TxnMaxReqLen), + UIConfig: b.uiConfigVal(c.UIConfig), + UnixSocketGroup: stringVal(c.UnixSocket.Group), + UnixSocketMode: stringVal(c.UnixSocket.Mode), + UnixSocketUser: stringVal(c.UnixSocket.User), + Watches: c.Watches, + AutoReloadConfigCoalesceInterval: 1 * time.Second, } rt.TLS, err = b.buildTLSConfig(rt, c.TLS) diff --git a/agent/config/file_watcher.go b/agent/config/file_watcher.go index d85abca4b..d62d19035 100644 --- a/agent/config/file_watcher.go +++ b/agent/config/file_watcher.go @@ -44,7 +44,7 @@ type watchedFile struct { } type FileWatcherEvent struct { - Filename string + Filenames []string } //NewFileWatcher create a file watcher that will watch all the files/folders from configFiles @@ -213,7 +213,7 @@ func (w *fileWatcher) handleEvent(ctx context.Context, event fsnotify.Event) err if isCreateEvent(event) || isWriteEvent(event) || isRenameEvent(event) { w.logger.Trace("call the handler", "filename", event.Name, "OP", event.Op) select { - case w.eventsCh <- &FileWatcherEvent{Filename: filename}: + case w.eventsCh <- &FileWatcherEvent{Filenames: []string{filename}}: case <-ctx.Done(): return ctx.Err() } @@ -265,7 +265,7 @@ func (w *fileWatcher) reconcile(ctx context.Context) { w.logger.Trace("call the handler", "filename", filename, "old modTime", configFile.modTime, "new modTime", newModTime) configFile.modTime = newModTime select { - case w.eventsCh <- &FileWatcherEvent{Filename: filename}: + case w.eventsCh <- &FileWatcherEvent{Filenames: []string{filename}}: case <-ctx.Done(): return } diff --git a/agent/config/file_watcher_test.go b/agent/config/file_watcher_test.go index 064729c53..52abb1328 100644 --- a/agent/config/file_watcher_test.go +++ b/agent/config/file_watcher_test.go @@ -64,6 +64,23 @@ func TestWatcherAddRemove(t *testing.T) { } +func TestWatcherReplace(t *testing.T) { + var filepaths []string + wi, err := NewFileWatcher(filepaths, hclog.New(&hclog.LoggerOptions{})) + w := wi.(*fileWatcher) + require.NoError(t, err) + file1 := createTempConfigFile(t, "temp_config1") + err = w.Add(file1) + require.NoError(t, err) + file2 := createTempConfigFile(t, "temp_config2") + err = w.Replace(file1, file2) + require.NoError(t, err) + _, ok := w.configFiles[file1] + require.False(t, ok) + _, ok = w.configFiles[file2] + require.True(t, ok) +} + func TestWatcherAddWhileRunning(t *testing.T) { var filepaths []string wi, err := NewFileWatcher(filepaths, hclog.New(&hclog.LoggerOptions{})) @@ -364,8 +381,8 @@ func TestEventWatcherMoveSoftLink(t *testing.T) { func assertEvent(name string, watcherCh chan *FileWatcherEvent, timeout time.Duration) error { select { case ev := <-watcherCh: - if ev.Filename != name && !strings.Contains(ev.Filename, name) { - return fmt.Errorf("filename do not match %s %s", ev.Filename, name) + if ev.Filenames[0] != name && !strings.Contains(ev.Filenames[0], name) { + return fmt.Errorf("filename do not match %s %s", ev.Filenames[0], name) } return nil case <-time.After(timeout): diff --git a/agent/config/ratelimited_file_watcher.go b/agent/config/ratelimited_file_watcher.go new file mode 100644 index 000000000..a47f9733d --- /dev/null +++ b/agent/config/ratelimited_file_watcher.go @@ -0,0 +1,90 @@ +package config + +import ( + "context" + "time" + + "github.com/hashicorp/go-hclog" +) + +type rateLimitedFileWatcher struct { + watcher Watcher + eventCh chan *FileWatcherEvent + coalesceInterval time.Duration +} + +func (r *rateLimitedFileWatcher) Start(ctx context.Context) { + r.watcher.Start(ctx) + r.coalesceTimer(ctx, r.watcher.EventsCh(), r.coalesceInterval) +} + +func (r rateLimitedFileWatcher) Stop() error { + return r.watcher.Stop() +} + +func (r rateLimitedFileWatcher) Add(filename string) error { + return r.watcher.Add(filename) +} + +func (r rateLimitedFileWatcher) Remove(filename string) { + r.watcher.Remove(filename) +} + +func (r rateLimitedFileWatcher) Replace(oldFile, newFile string) error { + return r.watcher.Replace(oldFile, newFile) +} + +func (r rateLimitedFileWatcher) EventsCh() chan *FileWatcherEvent { + return r.eventCh +} + +func NewRateLimitedFileWatcher(configFiles []string, logger hclog.Logger, coalesceInterval time.Duration) (Watcher, error) { + + watcher, err := NewFileWatcher(configFiles, logger) + if err != nil { + return nil, err + } + return &rateLimitedFileWatcher{ + watcher: watcher, + coalesceInterval: coalesceInterval, + eventCh: make(chan *FileWatcherEvent), + }, nil +} + +func (r rateLimitedFileWatcher) coalesceTimer(ctx context.Context, inputCh chan *FileWatcherEvent, coalesceDuration time.Duration) { + var ( + coalesceTimer *time.Timer + sendCh = make(chan struct{}) + fileWatcherEvents []string + ) + + go func() { + for { + select { + case event, ok := <-inputCh: + if !ok { + if len(fileWatcherEvents) > 0 { + r.eventCh <- &FileWatcherEvent{Filenames: fileWatcherEvents} + } + close(r.eventCh) + return + } + fileWatcherEvents = append(fileWatcherEvents, event.Filenames...) + if coalesceTimer == nil { + coalesceTimer = time.AfterFunc(coalesceDuration, func() { + // This runs in another goroutine so we can't just do the send + // directly here as access to fileWatcherEvents is racy. Instead, + // signal the main loop above. + sendCh <- struct{}{} + }) + } + case <-sendCh: + coalesceTimer = nil + r.eventCh <- &FileWatcherEvent{Filenames: fileWatcherEvents} + fileWatcherEvents = make([]string, 0) + case <-ctx.Done(): + return + } + } + }() +} diff --git a/agent/config/ratelimited_file_watcher_test.go b/agent/config/ratelimited_file_watcher_test.go new file mode 100644 index 000000000..ee1ecb8bb --- /dev/null +++ b/agent/config/ratelimited_file_watcher_test.go @@ -0,0 +1,91 @@ +package config + +import ( + "context" + "os" + "testing" + "time" + + "github.com/hashicorp/go-hclog" + + "github.com/hashicorp/consul/sdk/testutil" + "github.com/stretchr/testify/require" +) + +func TestNewRateLimitedWatcher(t *testing.T) { + w, err := NewRateLimitedFileWatcher([]string{}, hclog.New(&hclog.LoggerOptions{}), 1*time.Nanosecond) + require.NoError(t, err) + require.NotNil(t, w) +} + +func TestRateLimitedWatcherRenameEvent(t *testing.T) { + + fileTmp := createTempConfigFile(t, "temp_config3") + filepaths := []string{createTempConfigFile(t, "temp_config1"), createTempConfigFile(t, "temp_config2")} + w, err := NewRateLimitedFileWatcher(filepaths, hclog.New(&hclog.LoggerOptions{}), 1*time.Nanosecond) + + require.NoError(t, err) + w.Start(context.Background()) + defer func() { + _ = w.Stop() + }() + + require.NoError(t, err) + err = os.Rename(fileTmp, filepaths[0]) + time.Sleep(timeoutDuration + 50*time.Millisecond) + require.NoError(t, err) + require.NoError(t, assertEvent(filepaths[0], w.EventsCh(), defaultTimeout)) + // make sure we consume all events + _ = assertEvent(filepaths[0], w.EventsCh(), defaultTimeout) +} + +func TestRateLimitedWatcherAddNotExist(t *testing.T) { + + file := testutil.TempFile(t, "temp_config") + filename := file.Name() + randomStr(16) + w, err := NewRateLimitedFileWatcher([]string{filename}, hclog.New(&hclog.LoggerOptions{}), 1*time.Nanosecond) + require.Error(t, err, "no such file or directory") + require.Nil(t, w) +} + +func TestEventRateLimitedWatcherWrite(t *testing.T) { + + file := testutil.TempFile(t, "temp_config") + _, err := file.WriteString("test config") + require.NoError(t, err) + err = file.Sync() + require.NoError(t, err) + w, err := NewRateLimitedFileWatcher([]string{file.Name()}, hclog.New(&hclog.LoggerOptions{}), 1*time.Nanosecond) + require.NoError(t, err) + w.Start(context.Background()) + defer func() { + _ = w.Stop() + }() + + _, err = file.WriteString("test config 2") + require.NoError(t, err) + err = file.Sync() + require.NoError(t, err) + require.NoError(t, assertEvent(file.Name(), w.EventsCh(), defaultTimeout)) +} + +func TestEventRateLimitedWatcherMove(t *testing.T) { + + filepath := createTempConfigFile(t, "temp_config1") + + w, err := NewRateLimitedFileWatcher([]string{filepath}, hclog.New(&hclog.LoggerOptions{}), 1*time.Second) + require.NoError(t, err) + w.Start(context.Background()) + defer func() { + _ = w.Stop() + }() + + for i := 0; i < 10; i++ { + filepath2 := createTempConfigFile(t, "temp_config2") + err = os.Rename(filepath2, filepath) + time.Sleep(timeoutDuration + 50*time.Millisecond) + require.NoError(t, err) + } + require.NoError(t, assertEvent(filepath, w.EventsCh(), defaultTimeout)) + require.Error(t, assertEvent(filepath, w.EventsCh(), defaultTimeout), "expected timeout error") +} diff --git a/agent/config/runtime.go b/agent/config/runtime.go index 99c51f335..442393ba1 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -1399,6 +1399,9 @@ type RuntimeConfig struct { // Watches []map[string]interface{} + // AutoReloadConfigCoalesceInterval Coalesce Interval for auto reload config + AutoReloadConfigCoalesceInterval time.Duration + EnterpriseRuntimeConfig } diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index 408241e40..eb9d03d2b 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -6390,7 +6390,8 @@ func TestLoad_FullConfig(t *testing.T) { "args": []interface{}{"dltjDJ2a", "flEa7C2d"}, }, }, - RaftBoltDBConfig: consul.RaftBoltDBConfig{NoFreelistSync: true}, + RaftBoltDBConfig: consul.RaftBoltDBConfig{NoFreelistSync: true}, + AutoReloadConfigCoalesceInterval: 1 * time.Second, } entFullRuntimeConfig(expected) diff --git a/agent/config/testdata/TestRuntimeConfig_Sanitize.golden b/agent/config/testdata/TestRuntimeConfig_Sanitize.golden index 5356761e4..4fafb520b 100644 --- a/agent/config/testdata/TestRuntimeConfig_Sanitize.golden +++ b/agent/config/testdata/TestRuntimeConfig_Sanitize.golden @@ -64,6 +64,7 @@ "AutoEncryptIPSAN": [], "AutoEncryptTLS": false, "AutoReloadConfig": false, + "AutoReloadConfigCoalesceInterval": "0s", "AutopilotCleanupDeadServers": false, "AutopilotDisableUpgradeMigration": false, "AutopilotLastContactThreshold": "0s", @@ -456,4 +457,4 @@ "Version": "", "VersionPrerelease": "", "Watches": [] -} +} \ No newline at end of file diff --git a/agent/testagent.go b/agent/testagent.go index 3910a78d9..11ca9a518 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -221,6 +221,9 @@ func (a *TestAgent) Start(t *testing.T) error { bd.MetricsHandler = metrics.NewInmemSink(1*time.Second, time.Minute) } + if a.Config != nil && bd.RuntimeConfig.AutoReloadConfigCoalesceInterval == 0 { + bd.RuntimeConfig.AutoReloadConfigCoalesceInterval = a.Config.AutoReloadConfigCoalesceInterval + } a.Config = bd.RuntimeConfig agent, err := New(bd)