VAULT-12798 Correct removal behaviour when JWT is symlink (#18863)
* VAULT-12798 testing for jwt symlinks * VAULT-12798 Add testing of jwt removal * VAULT-12798 Update docs for clarity * VAULT-12798 Small change, and changelog * VAULT-12798 Lstat -> Stat * VAULT-12798 remove forgotten comment * VAULT-12798 small refactor, add new config item * VAULT-12798 Require opt-in config for following symlinks for JWT deletion * VAULT-12798 change changelog
This commit is contained in:
parent
e19dc98016
commit
85f845c3e0
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:improvement
|
||||||
|
agent: JWT auto-auth has a new config option, `remove_jwt_follows_symlinks` (default: false), that, if set to true will now remove the JWT, instead of the symlink to the JWT, if a symlink to a JWT has been provided in the `path` option, and the `remove_jwt_after_reading` config option is set to true (default).
|
||||||
|
```
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -23,6 +24,7 @@ type jwtMethod struct {
|
||||||
mountPath string
|
mountPath string
|
||||||
role string
|
role string
|
||||||
removeJWTAfterReading bool
|
removeJWTAfterReading bool
|
||||||
|
removeJWTFollowsSymlinks bool
|
||||||
credsFound chan struct{}
|
credsFound chan struct{}
|
||||||
watchCh chan string
|
watchCh chan string
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
|
@ -83,6 +85,14 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
||||||
j.removeJWTAfterReading = removeJWTAfterReading
|
j.removeJWTAfterReading = removeJWTAfterReading
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if removeJWTFollowsSymlinksRaw, ok := conf.Config["remove_jwt_follows_symlinks"]; ok {
|
||||||
|
removeJWTFollowsSymlinks, err := parseutil.ParseBool(removeJWTFollowsSymlinksRaw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing 'remove_jwt_follows_symlinks' value: %w", err)
|
||||||
|
}
|
||||||
|
j.removeJWTFollowsSymlinks = removeJWTFollowsSymlinks
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case j.path == "":
|
case j.path == "":
|
||||||
return nil, errors.New("'path' value is empty")
|
return nil, errors.New("'path' value is empty")
|
||||||
|
@ -90,13 +100,24 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
||||||
return nil, errors.New("'role' value is empty")
|
return nil, errors.New("'role' value is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default readPeriod
|
||||||
|
readPeriod := 1 * time.Minute
|
||||||
|
|
||||||
|
if jwtReadPeriodRaw, ok := conf.Config["jwt_read_period"]; ok {
|
||||||
|
jwtReadPeriod, err := parseutil.ParseDurationSecond(jwtReadPeriodRaw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing 'jwt_read_period' value: %w", err)
|
||||||
|
}
|
||||||
|
readPeriod = jwtReadPeriod
|
||||||
|
} else {
|
||||||
// If we don't delete the JWT after reading, use a slower reload period,
|
// If we don't delete the JWT after reading, use a slower reload period,
|
||||||
// otherwise we would re-read the whole file every 500ms, instead of just
|
// otherwise we would re-read the whole file every 500ms, instead of just
|
||||||
// doing a stat on the file every 500ms.
|
// doing a stat on the file every 500ms.
|
||||||
readPeriod := 1 * time.Minute
|
|
||||||
if j.removeJWTAfterReading {
|
if j.removeJWTAfterReading {
|
||||||
readPeriod = 500 * time.Millisecond
|
readPeriod = 500 * time.Millisecond
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
j.ticker = time.NewTicker(readPeriod)
|
j.ticker = time.NewTicker(readPeriod)
|
||||||
|
|
||||||
go j.runWatcher()
|
go j.runWatcher()
|
||||||
|
@ -147,8 +168,8 @@ func (j *jwtMethod) runWatcher() {
|
||||||
|
|
||||||
case <-j.credSuccessGate:
|
case <-j.credSuccessGate:
|
||||||
// We only start the next loop once we're initially successful,
|
// We only start the next loop once we're initially successful,
|
||||||
// since at startup Authenticate will be called and we don't want
|
// since at startup Authenticate will be called, and we don't want
|
||||||
// to end up immediately reauthenticating by having found a new
|
// to end up immediately re-authenticating by having found a new
|
||||||
// value
|
// value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,11 +203,27 @@ func (j *jwtMethod) ingressToken() {
|
||||||
// Check that the path refers to a file.
|
// Check that the path refers to a file.
|
||||||
// If it's a symlink, it could still be a symlink to a directory,
|
// If it's a symlink, it could still be a symlink to a directory,
|
||||||
// but os.ReadFile below will return a descriptive error.
|
// but os.ReadFile below will return a descriptive error.
|
||||||
|
evalSymlinkPath := j.path
|
||||||
switch mode := fi.Mode(); {
|
switch mode := fi.Mode(); {
|
||||||
case mode.IsRegular():
|
case mode.IsRegular():
|
||||||
// regular file
|
// regular file
|
||||||
case mode&fs.ModeSymlink != 0:
|
case mode&fs.ModeSymlink != 0:
|
||||||
// symlink
|
// If our file path is a symlink, we should also return early (like above) without error
|
||||||
|
// if the file that is linked to is not present, otherwise we will error when trying
|
||||||
|
// to read that file by following the link in the os.ReadFile call.
|
||||||
|
evalSymlinkPath, err = filepath.EvalSymlinks(j.path)
|
||||||
|
if err != nil {
|
||||||
|
j.logger.Error("error encountered evaluating symlinks", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err := os.Stat(evalSymlinkPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
j.logger.Error("error encountered stat'ing jwt file after evaluating symlinks", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
j.logger.Error("jwt file is not a regular file or symlink")
|
j.logger.Error("jwt file is not a regular file or symlink")
|
||||||
return
|
return
|
||||||
|
@ -207,7 +244,13 @@ func (j *jwtMethod) ingressToken() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if j.removeJWTAfterReading {
|
if j.removeJWTAfterReading {
|
||||||
if err := os.Remove(j.path); err != nil {
|
pathToRemove := j.path
|
||||||
|
if j.removeJWTFollowsSymlinks {
|
||||||
|
// If removeJWTFollowsSymlinks is set, we follow the symlink and delete the jwt,
|
||||||
|
// not just the symlink that links to the jwt
|
||||||
|
pathToRemove = evalSymlinkPath
|
||||||
|
}
|
||||||
|
if err := os.Remove(pathToRemove); err != nil {
|
||||||
j.logger.Error("error removing jwt file", "error", err)
|
j.logger.Error("error removing jwt file", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -165,3 +165,95 @@ func TestDeleteAfterReading(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteAfterReadingSymlink(t *testing.T) {
|
||||||
|
for _, tc := range map[string]struct {
|
||||||
|
configValue string
|
||||||
|
shouldDelete bool
|
||||||
|
removeJWTFollowsSymlinks bool
|
||||||
|
}{
|
||||||
|
"default": {
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
"explicit true": {
|
||||||
|
"true",
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
"false": {
|
||||||
|
"false",
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
"default + removeJWTFollowsSymlinks": {
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
"explicit true + removeJWTFollowsSymlinks": {
|
||||||
|
"true",
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
"false + removeJWTFollowsSymlinks": {
|
||||||
|
"false",
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
rootDir, err := os.MkdirTemp("", "vault-agent-jwt-auth-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp dir: %s", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(rootDir)
|
||||||
|
tokenPath := path.Join(rootDir, "token")
|
||||||
|
err = os.WriteFile(tokenPath, []byte("test"), 0o644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
symlinkName := symlink.Name()
|
||||||
|
symlink.Close()
|
||||||
|
os.Remove(symlinkName)
|
||||||
|
os.Symlink(tokenPath, symlinkName)
|
||||||
|
|
||||||
|
config := &auth.AuthConfig{
|
||||||
|
Config: map[string]interface{}{
|
||||||
|
"path": symlinkName,
|
||||||
|
"role": "unusedrole",
|
||||||
|
},
|
||||||
|
Logger: hclog.Default(),
|
||||||
|
}
|
||||||
|
if tc.configValue != "" {
|
||||||
|
config.Config["remove_jwt_after_reading"] = tc.configValue
|
||||||
|
}
|
||||||
|
config.Config["remove_jwt_follows_symlinks"] = tc.removeJWTFollowsSymlinks
|
||||||
|
|
||||||
|
jwtAuth, err := NewJWTAuthMethod(config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtAuth.(*jwtMethod).ingressToken()
|
||||||
|
|
||||||
|
pathToCheck := symlinkName
|
||||||
|
if tc.removeJWTFollowsSymlinks {
|
||||||
|
pathToCheck = tokenPath
|
||||||
|
}
|
||||||
|
if _, err := os.Lstat(pathToCheck); tc.shouldDelete {
|
||||||
|
if err == nil || !os.IsNotExist(err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ package agent
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -24,11 +24,32 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWTEndToEnd(t *testing.T) {
|
func TestJWTEndToEnd(t *testing.T) {
|
||||||
testJWTEndToEnd(t, false)
|
t.Parallel()
|
||||||
testJWTEndToEnd(t, true)
|
testCases := []struct {
|
||||||
|
ahWrapping bool
|
||||||
|
useSymlink bool
|
||||||
|
removeJWTAfterReading bool
|
||||||
|
}{
|
||||||
|
{false, false, false},
|
||||||
|
{true, false, false},
|
||||||
|
{false, true, false},
|
||||||
|
{true, true, false},
|
||||||
|
{false, false, true},
|
||||||
|
{true, false, true},
|
||||||
|
{false, true, true},
|
||||||
|
{true, true, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc // capture range variable
|
||||||
|
t.Run(fmt.Sprintf("ahWrapping=%v, useSymlink=%v, removeJWTAfterReading=%v", tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading), func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testJWTEndToEnd(t, tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
func testJWTEndToEnd(t *testing.T, ahWrapping, useSymlink, removeJWTAfterReading bool) {
|
||||||
logger := logging.NewVaultLogger(hclog.Trace)
|
logger := logging.NewVaultLogger(hclog.Trace)
|
||||||
coreConfig := &vault.CoreConfig{
|
coreConfig := &vault.CoreConfig{
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
|
@ -83,16 +104,24 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||||
|
|
||||||
// We close these right away because we're just basically testing
|
// We close these right away because we're just basically testing
|
||||||
// permissions and finding a usable file name
|
// permissions and finding a usable file name
|
||||||
inf, err := ioutil.TempFile("", "auth.jwt.test.")
|
inf, err := os.CreateTemp("", "auth.jwt.test.")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
in := inf.Name()
|
in := inf.Name()
|
||||||
inf.Close()
|
inf.Close()
|
||||||
os.Remove(in)
|
os.Remove(in)
|
||||||
|
symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
symlinkName := symlink.Name()
|
||||||
|
symlink.Close()
|
||||||
|
os.Remove(symlinkName)
|
||||||
|
os.Symlink(in, symlinkName)
|
||||||
t.Logf("input: %s", in)
|
t.Logf("input: %s", in)
|
||||||
|
|
||||||
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
|
ouf, err := os.CreateTemp("", "auth.tokensink.test.")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -101,7 +130,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||||
os.Remove(out)
|
os.Remove(out)
|
||||||
t.Logf("output: %s", out)
|
t.Logf("output: %s", out)
|
||||||
|
|
||||||
dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.")
|
dhpathf, err := os.CreateTemp("", "auth.dhpath.test.")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -116,7 +145,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := ioutil.WriteFile(dhpath, mPubKey, 0o600); err != nil {
|
if err := os.WriteFile(dhpath, mPubKey, 0o600); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else {
|
} else {
|
||||||
logger.Trace("wrote dh param file", "path", dhpath)
|
logger.Trace("wrote dh param file", "path", dhpath)
|
||||||
|
@ -124,12 +153,21 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
|
||||||
|
var fileNameToUseAsPath string
|
||||||
|
if useSymlink {
|
||||||
|
fileNameToUseAsPath = symlinkName
|
||||||
|
} else {
|
||||||
|
fileNameToUseAsPath = in
|
||||||
|
}
|
||||||
am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{
|
am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{
|
||||||
Logger: logger.Named("auth.jwt"),
|
Logger: logger.Named("auth.jwt"),
|
||||||
MountPath: "auth/jwt",
|
MountPath: "auth/jwt",
|
||||||
Config: map[string]interface{}{
|
Config: map[string]interface{}{
|
||||||
"path": in,
|
"path": fileNameToUseAsPath,
|
||||||
"role": "test",
|
"role": "test",
|
||||||
|
"remove_jwt_after_reading": removeJWTAfterReading,
|
||||||
|
"remove_jwt_follows_symlinks": true,
|
||||||
|
"jwt_read_period": "0.5s",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -225,7 +263,8 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||||
|
|
||||||
// Get a token
|
// Get a token
|
||||||
jwtToken, _ := GetTestJWT(t)
|
jwtToken, _ := GetTestJWT(t)
|
||||||
if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
|
||||||
|
if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else {
|
} else {
|
||||||
logger.Trace("wrote test jwt", "path", in)
|
logger.Trace("wrote test jwt", "path", in)
|
||||||
|
@ -237,13 +276,29 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||||
if time.Now().After(timeout) {
|
if time.Now().After(timeout) {
|
||||||
t.Fatal("did not find a written token after timeout")
|
t.Fatal("did not find a written token after timeout")
|
||||||
}
|
}
|
||||||
val, err := ioutil.ReadFile(out)
|
val, err := os.ReadFile(out)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
os.Remove(out)
|
os.Remove(out)
|
||||||
if len(val) == 0 {
|
if len(val) == 0 {
|
||||||
t.Fatal("written token was empty")
|
t.Fatal("written token was empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// First, ensure JWT has been removed
|
||||||
|
if removeJWTAfterReading {
|
||||||
|
_, err = os.Stat(in)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("no error returned from stat, indicating the jwt is still present")
|
||||||
|
}
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, err := os.Stat(in)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("JWT file removed despite removeJWTAfterReading being set to false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// First decrypt it
|
// First decrypt it
|
||||||
resp := new(dhutil.Envelope)
|
resp := new(dhutil.Envelope)
|
||||||
if err := jsonutil.DecodeJSON(val, resp); err != nil {
|
if err := jsonutil.DecodeJSON(val, resp); err != nil {
|
||||||
|
@ -336,7 +391,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||||
// Get another token to test the backend pushing the need to authenticate
|
// Get another token to test the backend pushing the need to authenticate
|
||||||
// to the handler
|
// to the handler
|
||||||
jwtToken, _ = GetTestJWT(t)
|
jwtToken, _ = GetTestJWT(t)
|
||||||
if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -61,7 +60,7 @@ func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func readToken(fileName string) (*logical.HTTPWrapInfo, error) {
|
func readToken(fileName string) (*logical.HTTPWrapInfo, error) {
|
||||||
b, err := ioutil.ReadFile(fileName)
|
b, err := os.ReadFile(fileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,3 +18,13 @@ method](/vault/docs/auth/jwt).
|
||||||
- `remove_jwt_after_reading` `(bool: optional, defaults to true)` -
|
- `remove_jwt_after_reading` `(bool: optional, defaults to true)` -
|
||||||
This can be set to `false` to disable the default behavior of removing the
|
This can be set to `false` to disable the default behavior of removing the
|
||||||
JWT after it's been read.
|
JWT after it's been read.
|
||||||
|
|
||||||
|
- `remove_jwt_follows_symlinks` `(bool: optional, defaults to false)` -
|
||||||
|
This can be set to `true` to follow symlinks when removing the JWT after it has been read
|
||||||
|
when executing the `remove_jwt_after_reading` behaviour. If set to false, it will delete
|
||||||
|
the symlink, not the JWT. Does nothing if `remove_jwt_after_reading` is false.
|
||||||
|
|
||||||
|
- `jwt_read_period` `(duration: "0.5s", optional)` - The duration after which
|
||||||
|
Agent will attempt to read the JWT stored at `path`. Defaults to `1m` if
|
||||||
|
`remove_jwt_after_reading` is set to `true`, or `0.5s` otherwise.
|
||||||
|
Uses [duration format strings](/docs/concepts/duration-format).
|
||||||
|
|
Loading…
Reference in New Issue