VAULT-12798 Correct removal behaviour when JWT is symlink (#18863)

* VAULT-12798 testing for jwt symlinks

* VAULT-12798 Add testing of jwt removal

* VAULT-12798 Update docs for clarity

* VAULT-12798 Small change, and changelog

* VAULT-12798 Lstat -> Stat

* VAULT-12798 remove forgotten comment

* VAULT-12798 small refactor, add new config item

* VAULT-12798 Require opt-in config for following symlinks for JWT deletion

* VAULT-12798 change changelog
This commit is contained in:
Violet Hynes 2023-03-14 15:44:19 -04:00 committed by GitHub
parent e19dc98016
commit 85f845c3e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 239 additions and 37 deletions

3
changelog/18863.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
agent: JWT auto-auth has a new config option, `remove_jwt_follows_symlinks` (default: false), that, if set to true will now remove the JWT, instead of the symlink to the JWT, if a symlink to a JWT has been provided in the `path` option, and the `remove_jwt_after_reading` config option is set to true (default).
```

View File

@ -7,6 +7,7 @@ import (
"io/fs" "io/fs"
"net/http" "net/http"
"os" "os"
"path/filepath"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -18,19 +19,20 @@ import (
) )
type jwtMethod struct { type jwtMethod struct {
logger hclog.Logger logger hclog.Logger
path string path string
mountPath string mountPath string
role string role string
removeJWTAfterReading bool removeJWTAfterReading bool
credsFound chan struct{} removeJWTFollowsSymlinks bool
watchCh chan string credsFound chan struct{}
stopCh chan struct{} watchCh chan string
doneCh chan struct{} stopCh chan struct{}
credSuccessGate chan struct{} doneCh chan struct{}
ticker *time.Ticker credSuccessGate chan struct{}
once *sync.Once ticker *time.Ticker
latestToken *atomic.Value once *sync.Once
latestToken *atomic.Value
} }
// NewJWTAuthMethod returns an implementation of Agent's auth.AuthMethod // NewJWTAuthMethod returns an implementation of Agent's auth.AuthMethod
@ -83,6 +85,14 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
j.removeJWTAfterReading = removeJWTAfterReading j.removeJWTAfterReading = removeJWTAfterReading
} }
if removeJWTFollowsSymlinksRaw, ok := conf.Config["remove_jwt_follows_symlinks"]; ok {
removeJWTFollowsSymlinks, err := parseutil.ParseBool(removeJWTFollowsSymlinksRaw)
if err != nil {
return nil, fmt.Errorf("error parsing 'remove_jwt_follows_symlinks' value: %w", err)
}
j.removeJWTFollowsSymlinks = removeJWTFollowsSymlinks
}
switch { switch {
case j.path == "": case j.path == "":
return nil, errors.New("'path' value is empty") return nil, errors.New("'path' value is empty")
@ -90,13 +100,24 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return nil, errors.New("'role' value is empty") return nil, errors.New("'role' value is empty")
} }
// If we don't delete the JWT after reading, use a slower reload period, // Default readPeriod
// otherwise we would re-read the whole file every 500ms, instead of just
// doing a stat on the file every 500ms.
readPeriod := 1 * time.Minute readPeriod := 1 * time.Minute
if j.removeJWTAfterReading {
readPeriod = 500 * time.Millisecond if jwtReadPeriodRaw, ok := conf.Config["jwt_read_period"]; ok {
jwtReadPeriod, err := parseutil.ParseDurationSecond(jwtReadPeriodRaw)
if err != nil {
return nil, fmt.Errorf("error parsing 'jwt_read_period' value: %w", err)
}
readPeriod = jwtReadPeriod
} else {
// If we don't delete the JWT after reading, use a slower reload period,
// otherwise we would re-read the whole file every 500ms, instead of just
// doing a stat on the file every 500ms.
if j.removeJWTAfterReading {
readPeriod = 500 * time.Millisecond
}
} }
j.ticker = time.NewTicker(readPeriod) j.ticker = time.NewTicker(readPeriod)
go j.runWatcher() go j.runWatcher()
@ -147,8 +168,8 @@ func (j *jwtMethod) runWatcher() {
case <-j.credSuccessGate: case <-j.credSuccessGate:
// We only start the next loop once we're initially successful, // We only start the next loop once we're initially successful,
// since at startup Authenticate will be called and we don't want // since at startup Authenticate will be called, and we don't want
// to end up immediately reauthenticating by having found a new // to end up immediately re-authenticating by having found a new
// value // value
} }
@ -182,11 +203,27 @@ func (j *jwtMethod) ingressToken() {
// Check that the path refers to a file. // Check that the path refers to a file.
// If it's a symlink, it could still be a symlink to a directory, // If it's a symlink, it could still be a symlink to a directory,
// but os.ReadFile below will return a descriptive error. // but os.ReadFile below will return a descriptive error.
evalSymlinkPath := j.path
switch mode := fi.Mode(); { switch mode := fi.Mode(); {
case mode.IsRegular(): case mode.IsRegular():
// regular file // regular file
case mode&fs.ModeSymlink != 0: case mode&fs.ModeSymlink != 0:
// symlink // If our file path is a symlink, we should also return early (like above) without error
// if the file that is linked to is not present, otherwise we will error when trying
// to read that file by following the link in the os.ReadFile call.
evalSymlinkPath, err = filepath.EvalSymlinks(j.path)
if err != nil {
j.logger.Error("error encountered evaluating symlinks", "error", err)
return
}
_, err := os.Stat(evalSymlinkPath)
if err != nil {
if os.IsNotExist(err) {
return
}
j.logger.Error("error encountered stat'ing jwt file after evaluating symlinks", "error", err)
return
}
default: default:
j.logger.Error("jwt file is not a regular file or symlink") j.logger.Error("jwt file is not a regular file or symlink")
return return
@ -207,7 +244,13 @@ func (j *jwtMethod) ingressToken() {
} }
if j.removeJWTAfterReading { if j.removeJWTAfterReading {
if err := os.Remove(j.path); err != nil { pathToRemove := j.path
if j.removeJWTFollowsSymlinks {
// If removeJWTFollowsSymlinks is set, we follow the symlink and delete the jwt,
// not just the symlink that links to the jwt
pathToRemove = evalSymlinkPath
}
if err := os.Remove(pathToRemove); err != nil {
j.logger.Error("error removing jwt file", "error", err) j.logger.Error("error removing jwt file", "error", err)
} }
} }

View File

@ -165,3 +165,95 @@ func TestDeleteAfterReading(t *testing.T) {
} }
} }
} }
func TestDeleteAfterReadingSymlink(t *testing.T) {
for _, tc := range map[string]struct {
configValue string
shouldDelete bool
removeJWTFollowsSymlinks bool
}{
"default": {
"",
true,
false,
},
"explicit true": {
"true",
true,
false,
},
"false": {
"false",
false,
false,
},
"default + removeJWTFollowsSymlinks": {
"",
true,
true,
},
"explicit true + removeJWTFollowsSymlinks": {
"true",
true,
true,
},
"false + removeJWTFollowsSymlinks": {
"false",
false,
true,
},
} {
rootDir, err := os.MkdirTemp("", "vault-agent-jwt-auth-test")
if err != nil {
t.Fatalf("failed to create temp dir: %s", err)
}
defer os.RemoveAll(rootDir)
tokenPath := path.Join(rootDir, "token")
err = os.WriteFile(tokenPath, []byte("test"), 0o644)
if err != nil {
t.Fatal(err)
}
symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.")
if err != nil {
t.Fatal(err)
}
symlinkName := symlink.Name()
symlink.Close()
os.Remove(symlinkName)
os.Symlink(tokenPath, symlinkName)
config := &auth.AuthConfig{
Config: map[string]interface{}{
"path": symlinkName,
"role": "unusedrole",
},
Logger: hclog.Default(),
}
if tc.configValue != "" {
config.Config["remove_jwt_after_reading"] = tc.configValue
}
config.Config["remove_jwt_follows_symlinks"] = tc.removeJWTFollowsSymlinks
jwtAuth, err := NewJWTAuthMethod(config)
if err != nil {
t.Fatal(err)
}
jwtAuth.(*jwtMethod).ingressToken()
pathToCheck := symlinkName
if tc.removeJWTFollowsSymlinks {
pathToCheck = tokenPath
}
if _, err := os.Lstat(pathToCheck); tc.shouldDelete {
if err == nil || !os.IsNotExist(err) {
t.Fatal(err)
}
} else {
if err != nil {
t.Fatal(err)
}
}
}
}

View File

@ -3,7 +3,7 @@ package agent
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"io/ioutil" "fmt"
"os" "os"
"testing" "testing"
"time" "time"
@ -24,11 +24,32 @@ import (
) )
func TestJWTEndToEnd(t *testing.T) { func TestJWTEndToEnd(t *testing.T) {
testJWTEndToEnd(t, false) t.Parallel()
testJWTEndToEnd(t, true) testCases := []struct {
ahWrapping bool
useSymlink bool
removeJWTAfterReading bool
}{
{false, false, false},
{true, false, false},
{false, true, false},
{true, true, false},
{false, false, true},
{true, false, true},
{false, true, true},
{true, true, true},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(fmt.Sprintf("ahWrapping=%v, useSymlink=%v, removeJWTAfterReading=%v", tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading), func(t *testing.T) {
t.Parallel()
testJWTEndToEnd(t, tc.ahWrapping, tc.useSymlink, tc.removeJWTAfterReading)
})
}
} }
func testJWTEndToEnd(t *testing.T, ahWrapping bool) { func testJWTEndToEnd(t *testing.T, ahWrapping, useSymlink, removeJWTAfterReading bool) {
logger := logging.NewVaultLogger(hclog.Trace) logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{ coreConfig := &vault.CoreConfig{
Logger: logger, Logger: logger,
@ -83,16 +104,24 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
// We close these right away because we're just basically testing // We close these right away because we're just basically testing
// permissions and finding a usable file name // permissions and finding a usable file name
inf, err := ioutil.TempFile("", "auth.jwt.test.") inf, err := os.CreateTemp("", "auth.jwt.test.")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
in := inf.Name() in := inf.Name()
inf.Close() inf.Close()
os.Remove(in) os.Remove(in)
symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.")
if err != nil {
t.Fatal(err)
}
symlinkName := symlink.Name()
symlink.Close()
os.Remove(symlinkName)
os.Symlink(in, symlinkName)
t.Logf("input: %s", in) t.Logf("input: %s", in)
ouf, err := ioutil.TempFile("", "auth.tokensink.test.") ouf, err := os.CreateTemp("", "auth.tokensink.test.")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -101,7 +130,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
os.Remove(out) os.Remove(out)
t.Logf("output: %s", out) t.Logf("output: %s", out)
dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.") dhpathf, err := os.CreateTemp("", "auth.dhpath.test.")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -116,7 +145,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := ioutil.WriteFile(dhpath, mPubKey, 0o600); err != nil { if err := os.WriteFile(dhpath, mPubKey, 0o600); err != nil {
t.Fatal(err) t.Fatal(err)
} else { } else {
logger.Trace("wrote dh param file", "path", dhpath) logger.Trace("wrote dh param file", "path", dhpath)
@ -124,12 +153,21 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
var fileNameToUseAsPath string
if useSymlink {
fileNameToUseAsPath = symlinkName
} else {
fileNameToUseAsPath = in
}
am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{ am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.jwt"), Logger: logger.Named("auth.jwt"),
MountPath: "auth/jwt", MountPath: "auth/jwt",
Config: map[string]interface{}{ Config: map[string]interface{}{
"path": in, "path": fileNameToUseAsPath,
"role": "test", "role": "test",
"remove_jwt_after_reading": removeJWTAfterReading,
"remove_jwt_follows_symlinks": true,
"jwt_read_period": "0.5s",
}, },
}) })
if err != nil { if err != nil {
@ -225,7 +263,8 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
// Get a token // Get a token
jwtToken, _ := GetTestJWT(t) jwtToken, _ := GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
t.Fatal(err) t.Fatal(err)
} else { } else {
logger.Trace("wrote test jwt", "path", in) logger.Trace("wrote test jwt", "path", in)
@ -237,13 +276,29 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
if time.Now().After(timeout) { if time.Now().After(timeout) {
t.Fatal("did not find a written token after timeout") t.Fatal("did not find a written token after timeout")
} }
val, err := ioutil.ReadFile(out) val, err := os.ReadFile(out)
if err == nil { if err == nil {
os.Remove(out) os.Remove(out)
if len(val) == 0 { if len(val) == 0 {
t.Fatal("written token was empty") t.Fatal("written token was empty")
} }
// First, ensure JWT has been removed
if removeJWTAfterReading {
_, err = os.Stat(in)
if err == nil {
t.Fatal("no error returned from stat, indicating the jwt is still present")
}
if !os.IsNotExist(err) {
t.Fatalf("unexpected error: %v", err)
}
} else {
_, err := os.Stat(in)
if err != nil {
t.Fatal("JWT file removed despite removeJWTAfterReading being set to false")
}
}
// First decrypt it // First decrypt it
resp := new(dhutil.Envelope) resp := new(dhutil.Envelope)
if err := jsonutil.DecodeJSON(val, resp); err != nil { if err := jsonutil.DecodeJSON(val, resp); err != nil {
@ -336,7 +391,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
// Get another token to test the backend pushing the need to authenticate // Get another token to test the backend pushing the need to authenticate
// to the handler // to the handler
jwtToken, _ = GetTestJWT(t) jwtToken, _ = GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0o600); err != nil { if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -6,7 +6,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
@ -61,7 +60,7 @@ func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
} }
func readToken(fileName string) (*logical.HTTPWrapInfo, error) { func readToken(fileName string) (*logical.HTTPWrapInfo, error) {
b, err := ioutil.ReadFile(fileName) b, err := os.ReadFile(fileName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -18,3 +18,13 @@ method](/vault/docs/auth/jwt).
- `remove_jwt_after_reading` `(bool: optional, defaults to true)` - - `remove_jwt_after_reading` `(bool: optional, defaults to true)` -
This can be set to `false` to disable the default behavior of removing the This can be set to `false` to disable the default behavior of removing the
JWT after it's been read. JWT after it's been read.
- `remove_jwt_follows_symlinks` `(bool: optional, defaults to false)` -
This can be set to `true` to follow symlinks when removing the JWT after it has been read
when executing the `remove_jwt_after_reading` behaviour. If set to false, it will delete
the symlink, not the JWT. Does nothing if `remove_jwt_after_reading` is false.
- `jwt_read_period` `(duration: "0.5s", optional)` - The duration after which
Agent will attempt to read the JWT stored at `path`. Defaults to `1m` if
`remove_jwt_after_reading` is set to `true`, or `0.5s` otherwise.
Uses [duration format strings](/docs/concepts/duration-format).