VAULT-12798 Correct removal behaviour when JWT is symlink (#18863)
* VAULT-12798 testing for jwt symlinks * VAULT-12798 Add testing of jwt removal * VAULT-12798 Update docs for clarity * VAULT-12798 Small change, and changelog * VAULT-12798 Lstat -> Stat * VAULT-12798 remove forgotten comment * VAULT-12798 small refactor, add new config item * VAULT-12798 Require opt-in config for following symlinks for JWT deletion * VAULT-12798 change changelog
This commit is contained in:
parent
e19dc98016
commit
85f845c3e0
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
agent: JWT auto-auth has a new config option, `remove_jwt_follows_symlinks` (default: false), that, if set to true will now remove the JWT, instead of the symlink to the JWT, if a symlink to a JWT has been provided in the `path` option, and the `remove_jwt_after_reading` config option is set to true (default).
|
||||
```
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -18,19 +19,20 @@ import (
|
|||
)
|
||||
|
||||
type jwtMethod struct {
|
||||
logger hclog.Logger
|
||||
path string
|
||||
mountPath string
|
||||
role string
|
||||
removeJWTAfterReading bool
|
||||
credsFound chan struct{}
|
||||
watchCh chan string
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
credSuccessGate chan struct{}
|
||||
ticker *time.Ticker
|
||||
once *sync.Once
|
||||
latestToken *atomic.Value
|
||||
logger hclog.Logger
|
||||
path string
|
||||
mountPath string
|
||||
role string
|
||||
removeJWTAfterReading bool
|
||||
removeJWTFollowsSymlinks bool
|
||||
credsFound chan struct{}
|
||||
watchCh chan string
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
credSuccessGate chan struct{}
|
||||
ticker *time.Ticker
|
||||
once *sync.Once
|
||||
latestToken *atomic.Value
|
||||
}
|
||||
|
||||
// NewJWTAuthMethod returns an implementation of Agent's auth.AuthMethod
|
||||
|
@ -83,6 +85,14 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
j.removeJWTAfterReading = removeJWTAfterReading
|
||||
}
|
||||
|
||||
if removeJWTFollowsSymlinksRaw, ok := conf.Config["remove_jwt_follows_symlinks"]; ok {
|
||||
removeJWTFollowsSymlinks, err := parseutil.ParseBool(removeJWTFollowsSymlinksRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing 'remove_jwt_follows_symlinks' value: %w", err)
|
||||
}
|
||||
j.removeJWTFollowsSymlinks = removeJWTFollowsSymlinks
|
||||
}
|
||||
|
||||
switch {
|
||||
case j.path == "":
|
||||
return nil, errors.New("'path' value is empty")
|
||||
|
@ -90,13 +100,24 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return nil, errors.New("'role' value is empty")
|
||||
}
|
||||
|
||||
// If we don't delete the JWT after reading, use a slower reload period,
|
||||
// otherwise we would re-read the whole file every 500ms, instead of just
|
||||
// doing a stat on the file every 500ms.
|
||||
// Default readPeriod
|
||||
readPeriod := 1 * time.Minute
|
||||
if j.removeJWTAfterReading {
|
||||
readPeriod = 500 * time.Millisecond
|
||||
|
||||
if jwtReadPeriodRaw, ok := conf.Config["jwt_read_period"]; ok {
|
||||
jwtReadPeriod, err := parseutil.ParseDurationSecond(jwtReadPeriodRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing 'jwt_read_period' value: %w", err)
|
||||
}
|
||||
readPeriod = jwtReadPeriod
|
||||
} else {
|
||||
// If we don't delete the JWT after reading, use a slower reload period,
|
||||
// otherwise we would re-read the whole file every 500ms, instead of just
|
||||
// doing a stat on the file every 500ms.
|
||||
if j.removeJWTAfterReading {
|
||||
readPeriod = 500 * time.Millisecond
|
||||
}
|
||||
}
|
||||
|
||||
j.ticker = time.NewTicker(readPeriod)
|
||||
|
||||
go j.runWatcher()
|
||||
|
@ -147,8 +168,8 @@ func (j *jwtMethod) runWatcher() {
|
|||
|
||||
case <-j.credSuccessGate:
|
||||
// We only start the next loop once we're initially successful,
|
||||
// since at startup Authenticate will be called and we don't want
|
||||
// to end up immediately reauthenticating by having found a new
|
||||
// since at startup Authenticate will be called, and we don't want
|
||||
// to end up immediately re-authenticating by having found a new
|
||||
// value
|
||||
}
|
||||
|
||||
|
@ -182,11 +203,27 @@ func (j *jwtMethod) ingressToken() {
|
|||
// Check that the path refers to a file.
|
||||
// If it's a symlink, it could still be a symlink to a directory,
|
||||
// but os.ReadFile below will return a descriptive error.
|
||||
evalSymlinkPath := j.path
|
||||
switch mode := fi.Mode(); {
|
||||
case mode.IsRegular():
|
||||
// regular file
|
||||
case mode&fs.ModeSymlink != 0:
|
||||
// symlink
|
||||
// If our file path is a symlink, we should also return early (like above) without error
|
||||
// if the file that is linked to is not present, otherwise we will error when trying
|
||||
// to read that file by following the link in the os.ReadFile call.
|
||||
evalSymlinkPath, err = filepath.EvalSymlinks(j.path)
|
||||
if err != nil {
|
||||
j.logger.Error("error encountered evaluating symlinks", "error", err)
|
||||
return
|
||||
}
|
||||
_, err := os.Stat(evalSymlinkPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
j.logger.Error("error encountered stat'ing jwt file after evaluating symlinks", "error", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
j.logger.Error("jwt file is not a regular file or symlink")
|
||||
return
|
||||
|
@ -207,7 +244,13 @@ func (j *jwtMethod) ingressToken() {
|
|||
}
|
||||
|
||||
if j.removeJWTAfterReading {
|
||||
if err := os.Remove(j.path); err != nil {
|
||||
pathToRemove := j.path
|
||||
if j.removeJWTFollowsSymlinks {
|
||||
// If removeJWTFollowsSymlinks is set, we follow the symlink and delete the jwt,
|
||||
// not just the symlink that links to the jwt
|
||||
pathToRemove = evalSymlinkPath
|
||||
}
|
||||
if err := os.Remove(pathToRemove); err != nil {
|
||||
j.logger.Error("error removing jwt file", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -165,3 +165,95 @@ func TestDeleteAfterReading(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAfterReadingSymlink(t *testing.T) {
|
||||
for _, tc := range map[string]struct {
|
||||
configValue string
|
||||
shouldDelete bool
|
||||
removeJWTFollowsSymlinks bool
|
||||
}{
|
||||
"default": {
|
||||
"",
|
||||
true,
|
||||
false,
|
||||
},
|
||||
"explicit true": {
|
||||
"true",
|
||||
true,
|
||||
false,
|
||||
},
|
||||
"false": {
|
||||
"false",
|
||||
false,
|
||||
false,
|
||||
},
|
||||
"default + removeJWTFollowsSymlinks": {
|
||||
"",
|
||||
true,
|
||||
true,
|
||||
},
|
||||
"explicit true + removeJWTFollowsSymlinks": {
|
||||
"true",
|
||||
true,
|
||||
true,
|
||||
},
|
||||
"false + removeJWTFollowsSymlinks": {
|
||||
"false",
|
||||
false,
|
||||
true,
|
||||
},
|
||||
} {
|
||||
rootDir, err := os.MkdirTemp("", "vault-agent-jwt-auth-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %s", err)
|
||||
}
|
||||
defer os.RemoveAll(rootDir)
|
||||
tokenPath := path.Join(rootDir, "token")
|
||||
err = os.WriteFile(tokenPath, []byte("test"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
symlinkName := symlink.Name()
|
||||
symlink.Close()
|
||||
os.Remove(symlinkName)
|
||||
os.Symlink(tokenPath, symlinkName)
|
||||
|
||||
config := &auth.AuthConfig{
|
||||
Config: map[string]interface{}{
|
||||
"path": symlinkName,
|
||||
"role": "unusedrole",
|
||||
},
|
||||
Logger: hclog.Default(),
|
||||
}
|
||||
if tc.configValue != "" {
|
||||
config.Config["remove_jwt_after_reading"] = tc.configValue
|
||||
}
|
||||
config.Config["remove_jwt_follows_symlinks"] = tc.removeJWTFollowsSymlinks
|
||||
|
||||
jwtAuth, err := NewJWTAuthMethod(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jwtAuth.(*jwtMethod).ingressToken()
|
||||
|
||||
pathToCheck := symlinkName
|
||||
if tc.removeJWTFollowsSymlinks {
|
||||
pathToCheck = tokenPath
|
||||
}
|
||||
if _, err := os.Lstat(pathToCheck); tc.shouldDelete {
|
||||
if err == nil || !os.IsNotExist(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package agent
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -24,11 +24,32 @@ import (
|
|||
)
|
||||
|
||||
func TestJWTEndToEnd(t *testing.T) {
|
||||
testJWTEndToEnd(t, false)
|
||||
testJWTEndToEnd(t, true)
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
ahWrapping bool
|
||||
useSymlink bool
|
||||
removeJWTAfterReading bool
|
||||
}{
|
||||
{false, false, false},
|
||||
{true, false, false},
|
||||
{false, true, false},
|
||||
{true, true, false},
|
||||
{false, false, true},
|
||||
{true, false, true},
|
||||
{false, true, true},
|
||||
{true, true, true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(fmt.Sprintf("ahWrapping=%v, useSymlink=%v, removeJWTAfterReading=%v", tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testJWTEndToEnd(t, tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
||||
func testJWTEndToEnd(t *testing.T, ahWrapping, useSymlink, removeJWTAfterReading bool) {
|
||||
logger := logging.NewVaultLogger(hclog.Trace)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
Logger: logger,
|
||||
|
@ -83,16 +104,24 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
|
||||
// We close these right away because we're just basically testing
|
||||
// permissions and finding a usable file name
|
||||
inf, err := ioutil.TempFile("", "auth.jwt.test.")
|
||||
inf, err := os.CreateTemp("", "auth.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
in := inf.Name()
|
||||
inf.Close()
|
||||
os.Remove(in)
|
||||
symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
symlinkName := symlink.Name()
|
||||
symlink.Close()
|
||||
os.Remove(symlinkName)
|
||||
os.Symlink(in, symlinkName)
|
||||
t.Logf("input: %s", in)
|
||||
|
||||
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
|
||||
ouf, err := os.CreateTemp("", "auth.tokensink.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -101,7 +130,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
os.Remove(out)
|
||||
t.Logf("output: %s", out)
|
||||
|
||||
dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.")
|
||||
dhpathf, err := os.CreateTemp("", "auth.dhpath.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -116,7 +145,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := ioutil.WriteFile(dhpath, mPubKey, 0o600); err != nil {
|
||||
if err := os.WriteFile(dhpath, mPubKey, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote dh param file", "path", dhpath)
|
||||
|
@ -124,12 +153,21 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
|
||||
var fileNameToUseAsPath string
|
||||
if useSymlink {
|
||||
fileNameToUseAsPath = symlinkName
|
||||
} else {
|
||||
fileNameToUseAsPath = in
|
||||
}
|
||||
am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{
|
||||
Logger: logger.Named("auth.jwt"),
|
||||
MountPath: "auth/jwt",
|
||||
Config: map[string]interface{}{
|
||||
"path": in,
|
||||
"role": "test",
|
||||
"path": fileNameToUseAsPath,
|
||||
"role": "test",
|
||||
"remove_jwt_after_reading": removeJWTAfterReading,
|
||||
"remove_jwt_follows_symlinks": true,
|
||||
"jwt_read_period": "0.5s",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -225,7 +263,8 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
|
||||
// Get a token
|
||||
jwtToken, _ := GetTestJWT(t)
|
||||
if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
||||
|
||||
if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test jwt", "path", in)
|
||||
|
@ -237,13 +276,29 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
if time.Now().After(timeout) {
|
||||
t.Fatal("did not find a written token after timeout")
|
||||
}
|
||||
val, err := ioutil.ReadFile(out)
|
||||
val, err := os.ReadFile(out)
|
||||
if err == nil {
|
||||
os.Remove(out)
|
||||
if len(val) == 0 {
|
||||
t.Fatal("written token was empty")
|
||||
}
|
||||
|
||||
// First, ensure JWT has been removed
|
||||
if removeJWTAfterReading {
|
||||
_, err = os.Stat(in)
|
||||
if err == nil {
|
||||
t.Fatal("no error returned from stat, indicating the jwt is still present")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
_, err := os.Stat(in)
|
||||
if err != nil {
|
||||
t.Fatal("JWT file removed despite removeJWTAfterReading being set to false")
|
||||
}
|
||||
}
|
||||
|
||||
// First decrypt it
|
||||
resp := new(dhutil.Envelope)
|
||||
if err := jsonutil.DecodeJSON(val, resp); err != nil {
|
||||
|
@ -336,7 +391,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
// Get another token to test the backend pushing the need to authenticate
|
||||
// to the handler
|
||||
jwtToken, _ = GetTestJWT(t)
|
||||
if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
||||
if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -61,7 +60,7 @@ func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
|
|||
}
|
||||
|
||||
func readToken(fileName string) (*logical.HTTPWrapInfo, error) {
|
||||
b, err := ioutil.ReadFile(fileName)
|
||||
b, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -18,3 +18,13 @@ method](/vault/docs/auth/jwt).
|
|||
- `remove_jwt_after_reading` `(bool: optional, defaults to true)` -
|
||||
This can be set to `false` to disable the default behavior of removing the
|
||||
JWT after it's been read.
|
||||
|
||||
- `remove_jwt_follows_symlinks` `(bool: optional, defaults to false)` -
|
||||
This can be set to `true` to follow symlinks when removing the JWT after it has been read
|
||||
when executing the `remove_jwt_after_reading` behaviour. If set to false, it will delete
|
||||
the symlink, not the JWT. Does nothing if `remove_jwt_after_reading` is false.
|
||||
|
||||
- `jwt_read_period` `(duration: "0.5s", optional)` - The duration after which
|
||||
Agent will attempt to read the JWT stored at `path`. Defaults to `1m` if
|
||||
`remove_jwt_after_reading` is set to `true`, or `0.5s` otherwise.
|
||||
Uses [duration format strings](/docs/concepts/duration-format).
|
||||
|
|
Loading…
Reference in New Issue