diff --git a/changelog/18863.txt b/changelog/18863.txt new file mode 100644 index 000000000..c1f2800c2 --- /dev/null +++ b/changelog/18863.txt @@ -0,0 +1,3 @@ +```release-note:improvement +agent: JWT auto-auth has a new config option, `remove_jwt_follows_symlinks` (default: false), that, if set to true will now remove the JWT, instead of the symlink to the JWT, if a symlink to a JWT has been provided in the `path` option, and the `remove_jwt_after_reading` config option is set to true (default). +``` \ No newline at end of file diff --git a/command/agent/auth/jwt/jwt.go b/command/agent/auth/jwt/jwt.go index 8f088eb19..ff96a32e0 100644 --- a/command/agent/auth/jwt/jwt.go +++ b/command/agent/auth/jwt/jwt.go @@ -7,6 +7,7 @@ import ( "io/fs" "net/http" "os" + "path/filepath" "sync" "sync/atomic" "time" @@ -18,19 +19,20 @@ import ( ) type jwtMethod struct { - logger hclog.Logger - path string - mountPath string - role string - removeJWTAfterReading bool - credsFound chan struct{} - watchCh chan string - stopCh chan struct{} - doneCh chan struct{} - credSuccessGate chan struct{} - ticker *time.Ticker - once *sync.Once - latestToken *atomic.Value + logger hclog.Logger + path string + mountPath string + role string + removeJWTAfterReading bool + removeJWTFollowsSymlinks bool + credsFound chan struct{} + watchCh chan string + stopCh chan struct{} + doneCh chan struct{} + credSuccessGate chan struct{} + ticker *time.Ticker + once *sync.Once + latestToken *atomic.Value } // NewJWTAuthMethod returns an implementation of Agent's auth.AuthMethod @@ -83,6 +85,14 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { j.removeJWTAfterReading = removeJWTAfterReading } + if removeJWTFollowsSymlinksRaw, ok := conf.Config["remove_jwt_follows_symlinks"]; ok { + removeJWTFollowsSymlinks, err := parseutil.ParseBool(removeJWTFollowsSymlinksRaw) + if err != nil { + return nil, fmt.Errorf("error parsing 'remove_jwt_follows_symlinks' value: %w", err) + } + j.removeJWTFollowsSymlinks = removeJWTFollowsSymlinks + } + switch { case j.path == "": return nil, errors.New("'path' value is empty") @@ -90,13 +100,24 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return nil, errors.New("'role' value is empty") } - // If we don't delete the JWT after reading, use a slower reload period, - // otherwise we would re-read the whole file every 500ms, instead of just - // doing a stat on the file every 500ms. + // Default readPeriod readPeriod := 1 * time.Minute - if j.removeJWTAfterReading { - readPeriod = 500 * time.Millisecond + + if jwtReadPeriodRaw, ok := conf.Config["jwt_read_period"]; ok { + jwtReadPeriod, err := parseutil.ParseDurationSecond(jwtReadPeriodRaw) + if err != nil { + return nil, fmt.Errorf("error parsing 'jwt_read_period' value: %w", err) + } + readPeriod = jwtReadPeriod + } else { + // If we don't delete the JWT after reading, use a slower reload period, + // otherwise we would re-read the whole file every 500ms, instead of just + // doing a stat on the file every 500ms. + if j.removeJWTAfterReading { + readPeriod = 500 * time.Millisecond + } } + j.ticker = time.NewTicker(readPeriod) go j.runWatcher() @@ -147,8 +168,8 @@ func (j *jwtMethod) runWatcher() { case <-j.credSuccessGate: // We only start the next loop once we're initially successful, - // since at startup Authenticate will be called and we don't want - // to end up immediately reauthenticating by having found a new + // since at startup Authenticate will be called, and we don't want + // to end up immediately re-authenticating by having found a new // value } @@ -182,11 +203,27 @@ func (j *jwtMethod) ingressToken() { // Check that the path refers to a file. // If it's a symlink, it could still be a symlink to a directory, // but os.ReadFile below will return a descriptive error. + evalSymlinkPath := j.path switch mode := fi.Mode(); { case mode.IsRegular(): // regular file case mode&fs.ModeSymlink != 0: - // symlink + // If our file path is a symlink, we should also return early (like above) without error + // if the file that is linked to is not present, otherwise we will error when trying + // to read that file by following the link in the os.ReadFile call. + evalSymlinkPath, err = filepath.EvalSymlinks(j.path) + if err != nil { + j.logger.Error("error encountered evaluating symlinks", "error", err) + return + } + _, err := os.Stat(evalSymlinkPath) + if err != nil { + if os.IsNotExist(err) { + return + } + j.logger.Error("error encountered stat'ing jwt file after evaluating symlinks", "error", err) + return + } default: j.logger.Error("jwt file is not a regular file or symlink") return @@ -207,7 +244,13 @@ func (j *jwtMethod) ingressToken() { } if j.removeJWTAfterReading { - if err := os.Remove(j.path); err != nil { + pathToRemove := j.path + if j.removeJWTFollowsSymlinks { + // If removeJWTFollowsSymlinks is set, we follow the symlink and delete the jwt, + // not just the symlink that links to the jwt + pathToRemove = evalSymlinkPath + } + if err := os.Remove(pathToRemove); err != nil { j.logger.Error("error removing jwt file", "error", err) } } diff --git a/command/agent/auth/jwt/jwt_test.go b/command/agent/auth/jwt/jwt_test.go index 8e9a2ae86..eb278dd01 100644 --- a/command/agent/auth/jwt/jwt_test.go +++ b/command/agent/auth/jwt/jwt_test.go @@ -165,3 +165,95 @@ func TestDeleteAfterReading(t *testing.T) { } } } + +func TestDeleteAfterReadingSymlink(t *testing.T) { + for _, tc := range map[string]struct { + configValue string + shouldDelete bool + removeJWTFollowsSymlinks bool + }{ + "default": { + "", + true, + false, + }, + "explicit true": { + "true", + true, + false, + }, + "false": { + "false", + false, + false, + }, + "default + removeJWTFollowsSymlinks": { + "", + true, + true, + }, + "explicit true + removeJWTFollowsSymlinks": { + "true", + true, + true, + }, + "false + removeJWTFollowsSymlinks": { + "false", + false, + true, + }, + } { + rootDir, err := os.MkdirTemp("", "vault-agent-jwt-auth-test") + if err != nil { + t.Fatalf("failed to create temp dir: %s", err) + } + defer os.RemoveAll(rootDir) + tokenPath := path.Join(rootDir, "token") + err = os.WriteFile(tokenPath, []byte("test"), 0o644) + if err != nil { + t.Fatal(err) + } + + symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.") + if err != nil { + t.Fatal(err) + } + symlinkName := symlink.Name() + symlink.Close() + os.Remove(symlinkName) + os.Symlink(tokenPath, symlinkName) + + config := &auth.AuthConfig{ + Config: map[string]interface{}{ + "path": symlinkName, + "role": "unusedrole", + }, + Logger: hclog.Default(), + } + if tc.configValue != "" { + config.Config["remove_jwt_after_reading"] = tc.configValue + } + config.Config["remove_jwt_follows_symlinks"] = tc.removeJWTFollowsSymlinks + + jwtAuth, err := NewJWTAuthMethod(config) + if err != nil { + t.Fatal(err) + } + + jwtAuth.(*jwtMethod).ingressToken() + + pathToCheck := symlinkName + if tc.removeJWTFollowsSymlinks { + pathToCheck = tokenPath + } + if _, err := os.Lstat(pathToCheck); tc.shouldDelete { + if err == nil || !os.IsNotExist(err) { + t.Fatal(err) + } + } else { + if err != nil { + t.Fatal(err) + } + } + } +} diff --git a/command/agent/jwt_end_to_end_test.go b/command/agent/jwt_end_to_end_test.go index c2d74d9f3..cf3824f3a 100644 --- a/command/agent/jwt_end_to_end_test.go +++ b/command/agent/jwt_end_to_end_test.go @@ -3,7 +3,7 @@ package agent import ( "context" "encoding/json" - "io/ioutil" + "fmt" "os" "testing" "time" @@ -24,11 +24,32 @@ import ( ) func TestJWTEndToEnd(t *testing.T) { - testJWTEndToEnd(t, false) - testJWTEndToEnd(t, true) + t.Parallel() + testCases := []struct { + ahWrapping bool + useSymlink bool + removeJWTAfterReading bool + }{ + {false, false, false}, + {true, false, false}, + {false, true, false}, + {true, true, false}, + {false, false, true}, + {true, false, true}, + {false, true, true}, + {true, true, true}, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(fmt.Sprintf("ahWrapping=%v, useSymlink=%v, removeJWTAfterReading=%v", tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading), func(t *testing.T) { + t.Parallel() + testJWTEndToEnd(t, tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading) + }) + } } -func testJWTEndToEnd(t *testing.T, ahWrapping bool) { +func testJWTEndToEnd(t *testing.T, ahWrapping, useSymlink, removeJWTAfterReading bool) { logger := logging.NewVaultLogger(hclog.Trace) coreConfig := &vault.CoreConfig{ Logger: logger, @@ -83,16 +104,24 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { // We close these right away because we're just basically testing // permissions and finding a usable file name - inf, err := ioutil.TempFile("", "auth.jwt.test.") + inf, err := os.CreateTemp("", "auth.jwt.test.") if err != nil { t.Fatal(err) } in := inf.Name() inf.Close() os.Remove(in) + symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.") + if err != nil { + t.Fatal(err) + } + symlinkName := symlink.Name() + symlink.Close() + os.Remove(symlinkName) + os.Symlink(in, symlinkName) t.Logf("input: %s", in) - ouf, err := ioutil.TempFile("", "auth.tokensink.test.") + ouf, err := os.CreateTemp("", "auth.tokensink.test.") if err != nil { t.Fatal(err) } @@ -101,7 +130,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { os.Remove(out) t.Logf("output: %s", out) - dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.") + dhpathf, err := os.CreateTemp("", "auth.dhpath.test.") if err != nil { t.Fatal(err) } @@ -116,7 +145,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { if err != nil { t.Fatal(err) } - if err := ioutil.WriteFile(dhpath, mPubKey, 0o600); err != nil { + if err := os.WriteFile(dhpath, mPubKey, 0o600); err != nil { t.Fatal(err) } else { logger.Trace("wrote dh param file", "path", dhpath) @@ -124,12 +153,21 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + var fileNameToUseAsPath string + if useSymlink { + fileNameToUseAsPath = symlinkName + } else { + fileNameToUseAsPath = in + } am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{ Logger: logger.Named("auth.jwt"), MountPath: "auth/jwt", Config: map[string]interface{}{ - "path": in, - "role": "test", + "path": fileNameToUseAsPath, + "role": "test", + "remove_jwt_after_reading": removeJWTAfterReading, + "remove_jwt_follows_symlinks": true, + "jwt_read_period": "0.5s", }, }) if err != nil { @@ -225,7 +263,8 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { // Get a token jwtToken, _ := GetTestJWT(t) - if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil { + + if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil { t.Fatal(err) } else { logger.Trace("wrote test jwt", "path", in) @@ -237,13 +276,29 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { if time.Now().After(timeout) { t.Fatal("did not find a written token after timeout") } - val, err := ioutil.ReadFile(out) + val, err := os.ReadFile(out) if err == nil { os.Remove(out) if len(val) == 0 { t.Fatal("written token was empty") } + // First, ensure JWT has been removed + if removeJWTAfterReading { + _, err = os.Stat(in) + if err == nil { + t.Fatal("no error returned from stat, indicating the jwt is still present") + } + if !os.IsNotExist(err) { + t.Fatalf("unexpected error: %v", err) + } + } else { + _, err := os.Stat(in) + if err != nil { + t.Fatal("JWT file removed despite removeJWTAfterReading being set to false") + } + } + // First decrypt it resp := new(dhutil.Envelope) if err := jsonutil.DecodeJSON(val, resp); err != nil { @@ -336,7 +391,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { // Get another token to test the backend pushing the need to authenticate // to the handler jwtToken, _ = GetTestJWT(t) - if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil { + if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil { t.Fatal(err) } diff --git a/command/agent/testing.go b/command/agent/testing.go index d4de988a9..fc8374de9 100644 --- a/command/agent/testing.go +++ b/command/agent/testing.go @@ -6,7 +6,6 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" - "io/ioutil" "os" "testing" "time" @@ -61,7 +60,7 @@ func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) { } func readToken(fileName string) (*logical.HTTPWrapInfo, error) { - b, err := ioutil.ReadFile(fileName) + b, err := os.ReadFile(fileName) if err != nil { return nil, err } diff --git a/website/content/docs/agent/autoauth/methods/jwt.mdx b/website/content/docs/agent/autoauth/methods/jwt.mdx index da290dda6..1d2a0861a 100644 --- a/website/content/docs/agent/autoauth/methods/jwt.mdx +++ b/website/content/docs/agent/autoauth/methods/jwt.mdx @@ -18,3 +18,13 @@ method](/vault/docs/auth/jwt). - `remove_jwt_after_reading` `(bool: optional, defaults to true)` - This can be set to `false` to disable the default behavior of removing the JWT after it's been read. + +- `remove_jwt_follows_symlinks` `(bool: optional, defaults to false)` - +This can be set to `true` to follow symlinks when removing the JWT after it has been read +when executing the `remove_jwt_after_reading` behaviour. If set to false, it will delete +the symlink, not the JWT. Does nothing if `remove_jwt_after_reading` is false. + +- `jwt_read_period` `(duration: "0.5s", optional)` - The duration after which +Agent will attempt to read the JWT stored at `path`. Defaults to `1m` if +`remove_jwt_after_reading` is set to `true`, or `0.5s` otherwise. +Uses [duration format strings](/docs/concepts/duration-format).