Add Kerberos auth agent (#7999)
* add kerberos auth agent * strip old comment * changes from feedback * strip appengine indirect dependency
This commit is contained in:
parent
2c6be02579
commit
c2894b8d05
|
@ -603,7 +603,7 @@ func (c *Client) ClearToken() {
|
|||
}
|
||||
|
||||
// Headers gets the current set of headers used for requests. This returns a
|
||||
// copy; to modify it make modifications locally and use SetHeaders.
|
||||
// copy; to modify it call AddHeader or SetHeaders.
|
||||
func (c *Client) Headers() http.Header {
|
||||
c.modifyLock.RLock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
|
@ -622,11 +622,19 @@ func (c *Client) Headers() http.Header {
|
|||
return ret
|
||||
}
|
||||
|
||||
// SetHeaders sets the headers to be used for future requests.
|
||||
// AddHeader allows a single header key/value pair to be added
|
||||
// in a race-safe fashion.
|
||||
func (c *Client) AddHeader(key, value string) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
c.headers.Add(key, value)
|
||||
}
|
||||
|
||||
// SetHeaders clears all previous headers and uses only the given
|
||||
// ones going forward.
|
||||
func (c *Client) SetHeaders(headers http.Header) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.headers = headers
|
||||
}
|
||||
|
||||
|
@ -680,8 +688,8 @@ func (c *Client) SetPolicyOverride(override bool) {
|
|||
|
||||
// portMap defines the standard port map
|
||||
var portMap = map[string]string{
|
||||
"http": "80",
|
||||
"https": "443",
|
||||
"http": "80",
|
||||
"https": "443",
|
||||
}
|
||||
|
||||
// NewRequest creates a new raw request object to query the Vault server
|
||||
|
@ -703,7 +711,7 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
|
|||
// Avoid lookup of SRV record if scheme is known
|
||||
port, ok := portMap[addr.Scheme]
|
||||
if ok {
|
||||
host = net.JoinHostPort(host, port)
|
||||
host = net.JoinHostPort(host, port)
|
||||
} else {
|
||||
// Internet Draft specifies that the SRV record is ignored if a port is given
|
||||
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())
|
||||
|
|
|
@ -325,3 +325,52 @@ func TestClone(t *testing.T) {
|
|||
|
||||
_ = client2
|
||||
}
|
||||
|
||||
func TestSetHeadersRaceSafe(t *testing.T) {
|
||||
client, err1 := NewClient(nil)
|
||||
if err1 != nil {
|
||||
t.Fatalf("NewClient failed: %v", err1)
|
||||
}
|
||||
|
||||
start := make(chan interface{})
|
||||
done := make(chan interface{})
|
||||
|
||||
testPairs := map[string]string{
|
||||
"soda": "rootbeer",
|
||||
"veggie": "carrots",
|
||||
"fruit": "apples",
|
||||
"color": "red",
|
||||
"protein": "egg",
|
||||
}
|
||||
|
||||
for key, value := range testPairs {
|
||||
tmpKey := key
|
||||
tmpValue := value
|
||||
go func() {
|
||||
<-start
|
||||
// This test fails if here, you replace client.AddHeader(tmpKey, tmpValue) with:
|
||||
// headerCopy := client.Header()
|
||||
// headerCopy.AddHeader(tmpKey, tmpValue)
|
||||
// client.SetHeader(headerCopy)
|
||||
client.AddHeader(tmpKey, tmpValue)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Start everyone at once.
|
||||
close(start)
|
||||
|
||||
// Wait until everyone is done.
|
||||
for i := 0; i < len(testPairs); i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Check that all the test pairs are in the resulting
|
||||
// headers.
|
||||
resultingHeaders := client.Headers()
|
||||
for key, value := range testPairs {
|
||||
if resultingHeaders.Get(key) != value {
|
||||
t.Fatal("expected " + value + " for " + key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1002,7 +1002,7 @@ func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
|
|||
/* ARN of an AWS role that Vault can query during testing.
|
||||
This role should exist in your current AWS account and your credentials
|
||||
should have iam:GetRole permissions to query it.
|
||||
*/
|
||||
*/
|
||||
assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN")
|
||||
if assumableRoleArn == "" {
|
||||
t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset")
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"github.com/go-test/deep"
|
||||
"math"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
|
@ -29,6 +28,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/go-test/deep"
|
||||
"github.com/hashicorp/vault/api"
|
||||
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/hashicorp/vault/command/agent/auth/cf"
|
||||
"github.com/hashicorp/vault/command/agent/auth/gcp"
|
||||
"github.com/hashicorp/vault/command/agent/auth/jwt"
|
||||
"github.com/hashicorp/vault/command/agent/auth/kerberos"
|
||||
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
||||
"github.com/hashicorp/vault/command/agent/cache"
|
||||
agentConfig "github.com/hashicorp/vault/command/agent/config"
|
||||
|
@ -385,6 +386,8 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
method, err = gcp.NewGCPAuthMethod(authConfig)
|
||||
case "jwt":
|
||||
method, err = jwt.NewJWTAuthMethod(authConfig)
|
||||
case "kerberos":
|
||||
method, err = kerberos.NewKerberosAuthMethod(authConfig)
|
||||
case "kubernetes":
|
||||
method, err = kubernetes.NewKubernetesAuthMethod(authConfig)
|
||||
case "approle":
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -174,16 +175,16 @@ type alicloudMethod struct {
|
|||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func (a *alicloudMethod) Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error) {
|
||||
func (a *alicloudMethod) Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
a.credLock.Lock()
|
||||
defer a.credLock.Unlock()
|
||||
|
||||
a.logger.Trace("beginning authentication")
|
||||
data, err := tools.GenerateLoginData(a.role, a.lastCreds, a.region)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, nil, err
|
||||
}
|
||||
return fmt.Sprintf("%s/login", a.mountPath), data, nil
|
||||
return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil
|
||||
}
|
||||
|
||||
func (a *alicloudMethod) NewCreds() chan struct{} {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
|
@ -87,18 +88,18 @@ func NewApproleAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return a, nil
|
||||
}
|
||||
|
||||
func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) {
|
||||
func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
if _, err := os.Stat(a.roleIDFilePath); err == nil {
|
||||
roleID, err := ioutil.ReadFile(a.roleIDFilePath)
|
||||
if err != nil {
|
||||
if a.cachedRoleID == "" {
|
||||
return "", nil, errwrap.Wrapf("error reading role ID file and no cached role ID known: {{err}}", err)
|
||||
return "", nil, nil, errwrap.Wrapf("error reading role ID file and no cached role ID known: {{err}}", err)
|
||||
}
|
||||
a.logger.Warn("error reading role ID file", "error", err)
|
||||
}
|
||||
if len(roleID) == 0 {
|
||||
if a.cachedRoleID == "" {
|
||||
return "", nil, errors.New("role ID file empty and no cached role ID known")
|
||||
return "", nil, nil, errors.New("role ID file empty and no cached role ID known")
|
||||
}
|
||||
a.logger.Warn("role ID file exists but read empty value, re-using cached value")
|
||||
} else {
|
||||
|
@ -107,11 +108,11 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
|
|||
}
|
||||
|
||||
if a.cachedRoleID == "" {
|
||||
return "", nil, errors.New("no known role ID")
|
||||
return "", nil, nil, errors.New("no known role ID")
|
||||
}
|
||||
|
||||
if a.secretIDFilePath == "" {
|
||||
return fmt.Sprintf("%s/login", a.mountPath), map[string]interface{}{
|
||||
return fmt.Sprintf("%s/login", a.mountPath), nil, map[string]interface{}{
|
||||
"role_id": a.cachedRoleID,
|
||||
}, nil
|
||||
}
|
||||
|
@ -120,13 +121,13 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
|
|||
secretID, err := ioutil.ReadFile(a.secretIDFilePath)
|
||||
if err != nil {
|
||||
if a.cachedSecretID == "" {
|
||||
return "", nil, errwrap.Wrapf("error reading secret ID file and no cached secret ID known: {{err}}", err)
|
||||
return "", nil, nil, errwrap.Wrapf("error reading secret ID file and no cached secret ID known: {{err}}", err)
|
||||
}
|
||||
a.logger.Warn("error reading secret ID file", "error", err)
|
||||
}
|
||||
if len(secretID) == 0 {
|
||||
if a.cachedSecretID == "" {
|
||||
return "", nil, errors.New("secret ID file empty and no cached secret ID known")
|
||||
return "", nil, nil, errors.New("secret ID file empty and no cached secret ID known")
|
||||
}
|
||||
a.logger.Warn("secret ID file exists but read empty value, re-using cached value")
|
||||
} else {
|
||||
|
@ -134,50 +135,50 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
|
|||
if a.secretIDResponseWrappingPath != "" {
|
||||
clonedClient, err := client.Clone()
|
||||
if err != nil {
|
||||
return "", nil, errwrap.Wrapf("error cloning client to unwrap secret ID: {{err}}", err)
|
||||
return "", nil, nil, errwrap.Wrapf("error cloning client to unwrap secret ID: {{err}}", err)
|
||||
}
|
||||
clonedClient.SetToken(stringSecretID)
|
||||
// Validate the creation path
|
||||
resp, err := clonedClient.Logical().Read("sys/wrapping/lookup")
|
||||
if err != nil {
|
||||
return "", nil, errwrap.Wrapf("error looking up wrapped secret ID: {{err}}", err)
|
||||
return "", nil, nil, errwrap.Wrapf("error looking up wrapped secret ID: {{err}}", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return "", nil, errors.New("response nil when looking up wrapped secret ID")
|
||||
return "", nil, nil, errors.New("response nil when looking up wrapped secret ID")
|
||||
}
|
||||
if resp.Data == nil {
|
||||
return "", nil, errors.New("data in response nil when looking up wrapped secret ID")
|
||||
return "", nil, nil, errors.New("data in response nil when looking up wrapped secret ID")
|
||||
}
|
||||
creationPathRaw, ok := resp.Data["creation_path"]
|
||||
if !ok {
|
||||
return "", nil, errors.New("creation_path in response nil when looking up wrapped secret ID")
|
||||
return "", nil, nil, errors.New("creation_path in response nil when looking up wrapped secret ID")
|
||||
}
|
||||
creationPath, ok := creationPathRaw.(string)
|
||||
if !ok {
|
||||
return "", nil, errors.New("creation_path in response could not be parsed as string when looking up wrapped secret ID")
|
||||
return "", nil, nil, errors.New("creation_path in response could not be parsed as string when looking up wrapped secret ID")
|
||||
}
|
||||
if creationPath != a.secretIDResponseWrappingPath {
|
||||
a.logger.Error("SECURITY: unable to validate wrapping token creation path", "expected", a.secretIDResponseWrappingPath, "found", creationPath)
|
||||
return "", nil, errors.New("unable to validate wrapping token creation path")
|
||||
return "", nil, nil, errors.New("unable to validate wrapping token creation path")
|
||||
}
|
||||
// Now get the secret ID
|
||||
resp, err = clonedClient.Logical().Unwrap("")
|
||||
if err != nil {
|
||||
return "", nil, errwrap.Wrapf("error unwrapping secret ID: {{err}}", err)
|
||||
return "", nil, nil, errwrap.Wrapf("error unwrapping secret ID: {{err}}", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return "", nil, errors.New("response nil when unwrapping secret ID")
|
||||
return "", nil, nil, errors.New("response nil when unwrapping secret ID")
|
||||
}
|
||||
if resp.Data == nil {
|
||||
return "", nil, errors.New("data in response nil when unwrapping secret ID")
|
||||
return "", nil, nil, errors.New("data in response nil when unwrapping secret ID")
|
||||
}
|
||||
secretIDRaw, ok := resp.Data["secret_id"]
|
||||
if !ok {
|
||||
return "", nil, errors.New("secret_id in response nil when unwrapping secret ID")
|
||||
return "", nil, nil, errors.New("secret_id in response nil when unwrapping secret ID")
|
||||
}
|
||||
secretID, ok := secretIDRaw.(string)
|
||||
if !ok {
|
||||
return "", nil, errors.New("secret_id in response could not be parsed as string when unwrapping secret ID")
|
||||
return "", nil, nil, errors.New("secret_id in response could not be parsed as string when unwrapping secret ID")
|
||||
}
|
||||
stringSecretID = secretID
|
||||
}
|
||||
|
@ -191,10 +192,10 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
|
|||
}
|
||||
|
||||
if a.cachedSecretID == "" {
|
||||
return "", nil, errors.New("no known secret ID")
|
||||
return "", nil, nil, errors.New("no known secret ID")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/login", a.mountPath), map[string]interface{}{
|
||||
return fmt.Sprintf("%s/login", a.mountPath), nil, map[string]interface{}{
|
||||
"role_id": a.cachedRoleID,
|
||||
"secret_id": a.cachedSecretID,
|
||||
}, nil
|
||||
|
|
|
@ -3,6 +3,7 @@ package auth
|
|||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
|
@ -11,7 +12,9 @@ import (
|
|||
)
|
||||
|
||||
type AuthMethod interface {
|
||||
Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error)
|
||||
// Authenticate returns a mount path, header, request body, and error.
|
||||
// The header may be nil if no special header is needed.
|
||||
Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error)
|
||||
NewCreds() chan struct{}
|
||||
CredSuccess()
|
||||
Shutdown()
|
||||
|
@ -119,7 +122,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
|
|||
backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))
|
||||
|
||||
ah.logger.Info("authenticating")
|
||||
path, data, err := am.Authenticate(ctx, ah.client)
|
||||
path, header, data, err := am.Authenticate(ctx, ah.client)
|
||||
if err != nil {
|
||||
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
|
||||
backoffOrQuit(ctx, backoff)
|
||||
|
@ -139,6 +142,11 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
|
|||
})
|
||||
clientToUse = wrapClient
|
||||
}
|
||||
for key, values := range header {
|
||||
for _, value := range values {
|
||||
clientToUse.AddHeader(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
secret, err := clientToUse.Logical().Write(path, data)
|
||||
// Check errors/sanity
|
||||
|
|
|
@ -2,6 +2,7 @@ package auth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -31,14 +32,14 @@ func newUserpassTestMethod(t *testing.T, client *api.Client) AuthMethod {
|
|||
return &userpassTestMethod{}
|
||||
}
|
||||
|
||||
func (u *userpassTestMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
|
||||
func (u *userpassTestMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
_, err := client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
|
||||
"password": "bar",
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, nil, err
|
||||
}
|
||||
return "auth/userpass/login/foo", map[string]interface{}{
|
||||
return "auth/userpass/login/foo", nil, map[string]interface{}{
|
||||
"password": "bar",
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -179,7 +179,7 @@ func NewAWSAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return a, nil
|
||||
}
|
||||
|
||||
func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retToken string, retData map[string]interface{}, retErr error) {
|
||||
func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retToken string, header http.Header, retData map[string]interface{}, retErr error) {
|
||||
a.logger.Trace("beginning authentication")
|
||||
|
||||
data := make(map[string]interface{})
|
||||
|
@ -266,7 +266,7 @@ func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retTo
|
|||
|
||||
data["role"] = a.role
|
||||
|
||||
return fmt.Sprintf("%s/login", a.mountPath), data, nil
|
||||
return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil
|
||||
}
|
||||
|
||||
func (a *awsMethod) NewCreds() chan struct{} {
|
||||
|
|
|
@ -74,7 +74,7 @@ func NewAzureAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return a, nil
|
||||
}
|
||||
|
||||
func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) {
|
||||
func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, header http.Header, retData map[string]interface{}, retErr error) {
|
||||
a.logger.Trace("beginning authentication")
|
||||
|
||||
// Fetch instance data
|
||||
|
@ -126,7 +126,7 @@ func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (ret
|
|||
"jwt": identity.AccessToken,
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/login", a.mountPath), data, nil
|
||||
return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil
|
||||
}
|
||||
|
||||
func (a *azureMethod) NewCreds() chan struct{} {
|
||||
|
|
|
@ -4,8 +4,9 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/command/agent/auth"
|
||||
)
|
||||
|
@ -44,7 +45,7 @@ func NewCertAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
|
||||
func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
c.logger.Trace("beginning authentication")
|
||||
|
||||
authMap := map[string]interface{}{}
|
||||
|
@ -53,7 +54,7 @@ func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string
|
|||
authMap["name"] = c.name
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/login", c.mountPath), authMap, nil
|
||||
return fmt.Sprintf("%s/login", c.mountPath), nil, authMap, nil
|
||||
}
|
||||
|
||||
func (c *certMethod) NewCreds() chan struct{} {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
|
@ -41,18 +42,18 @@ func NewCFAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return a, nil
|
||||
}
|
||||
|
||||
func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) {
|
||||
func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
pathToClientCert := os.Getenv(cf.EnvVarInstanceCertificate)
|
||||
if pathToClientCert == "" {
|
||||
return "", nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceCertificate)
|
||||
return "", nil, nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceCertificate)
|
||||
}
|
||||
certBytes, err := ioutil.ReadFile(pathToClientCert)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, nil, err
|
||||
}
|
||||
pathToClientKey := os.Getenv(cf.EnvVarInstanceKey)
|
||||
if pathToClientKey == "" {
|
||||
return "", nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceKey)
|
||||
return "", nil, nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceKey)
|
||||
}
|
||||
signingTime := time.Now().UTC()
|
||||
signatureData := &signatures.SignatureData{
|
||||
|
@ -62,7 +63,7 @@ func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string
|
|||
}
|
||||
signature, err := signatures.Sign(pathToClientKey, signatureData)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, nil, err
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"role": p.roleName,
|
||||
|
@ -70,7 +71,7 @@ func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string
|
|||
"signing_time": signingTime.Format(signatures.TimeFormat),
|
||||
"signature": signature,
|
||||
}
|
||||
return fmt.Sprintf("%s/login", p.mountPath), data, nil
|
||||
return fmt.Sprintf("%s/login", p.mountPath), nil, data, nil
|
||||
}
|
||||
|
||||
func (p *cfMethod) NewCreds() chan struct{} {
|
||||
|
|
|
@ -116,7 +116,7 @@ func NewGCPAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return g, nil
|
||||
}
|
||||
|
||||
func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) {
|
||||
func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, header http.Header, retData map[string]interface{}, retErr error) {
|
||||
g.logger.Trace("beginning authentication")
|
||||
|
||||
data := make(map[string]interface{})
|
||||
|
@ -227,7 +227,7 @@ func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPa
|
|||
data["role"] = g.role
|
||||
data["jwt"] = jwt
|
||||
|
||||
return fmt.Sprintf("%s/login", g.mountPath), data, nil
|
||||
return fmt.Sprintf("%s/login", g.mountPath), nil, data, nil
|
||||
}
|
||||
|
||||
func (g *gcpMethod) NewCreds() chan struct{} {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -85,17 +86,17 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return j, nil
|
||||
}
|
||||
|
||||
func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
|
||||
func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
j.logger.Trace("beginning authentication")
|
||||
|
||||
j.ingressToken()
|
||||
|
||||
latestToken := j.latestToken.Load().(string)
|
||||
if latestToken == "" {
|
||||
return "", nil, errors.New("latest known jwt is empty, cannot authenticate")
|
||||
return "", nil, nil, errors.New("latest known jwt is empty, cannot authenticate")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/login", j.mountPath), map[string]interface{}{
|
||||
return fmt.Sprintf("%s/login", j.mountPath), nil, map[string]interface{}{
|
||||
"role": j.role,
|
||||
"jwt": latestToken,
|
||||
}, nil
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
#!/bin/bash
|
||||
# Instructions
|
||||
# This integration test is for the Vault Kerberos agent.
|
||||
# Before running, execute:
|
||||
# pip install --quiet requests-kerberos
|
||||
# Then run this test from Vault's home directory.
|
||||
# ./command/agent/auth/kerberos/integtest/integrationtest.sh
|
||||
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
base64cmd="base64 -D"
|
||||
else
|
||||
base64cmd="base64 -d"
|
||||
fi
|
||||
|
||||
VAULT_PORT=8200
|
||||
SAMBA_VER=4.8.12
|
||||
|
||||
export VAULT_TOKEN=${VAULT_TOKEN:-myroot}
|
||||
DOMAIN_ADMIN_PASS=Pa55word!
|
||||
DOMAIN_VAULT_ACCOUNT=vault_svc
|
||||
DOMAIN_VAULT_PASS=vaultPa55word!
|
||||
DOMAIN_USER_ACCOUNT=grace
|
||||
DOMAIN_USER_PASS=gracePa55word!
|
||||
|
||||
SAMBA_CONF_FILE=/srv/etc/smb.conf
|
||||
DOMAIN_NAME=matrix
|
||||
DNS_NAME=host
|
||||
REALM_NAME=MATRIX.LAN
|
||||
DOMAIN_DN=DC=MATRIX,DC=LAN
|
||||
TESTS_DIR=/tmp/vault_plugin_tests
|
||||
|
||||
function add_user() {
|
||||
|
||||
username="${1}"
|
||||
password="${2}"
|
||||
|
||||
if [[ $(check_user ${username}) -eq 0 ]]
|
||||
then
|
||||
echo "add user '${username}'"
|
||||
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
/usr/bin/samba-tool user create \
|
||||
${username} \
|
||||
${password}\
|
||||
--configfile=${SAMBA_CONF_FILE}
|
||||
fi
|
||||
}
|
||||
|
||||
function check_user() {
|
||||
|
||||
username="${1}"
|
||||
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
/usr/bin/samba-tool user list \
|
||||
--configfile=${SAMBA_CONF_FILE} \
|
||||
| grep -c ${username}
|
||||
}
|
||||
|
||||
function create_keytab() {
|
||||
|
||||
username="${1}"
|
||||
password="${2}"
|
||||
|
||||
user_kvno=$(docker exec $SAMBA_CONTAINER \
|
||||
bash -c "ldapsearch -H ldaps://localhost -D \"Administrator@${REALM_NAME}\" -w \"${DOMAIN_ADMIN_PASS}\" -b \"CN=Users,${DOMAIN_DN}\" -LLL \"(&(objectClass=user)(sAMAccountName=${username}))\" msDS-KeyVersionNumber | sed -n 's/^[ \t]*msDS-KeyVersionNumber:[ \t]*\(.*\)/\1/p'")
|
||||
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
bash -c "printf \"%b\" \"addent -password -p \"${username}@${REALM_NAME}\" -k ${user_kvno} -e rc4-hmac\n${password}\nwrite_kt ${username}.keytab\" | ktutil"
|
||||
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
bash -c "printf \"%b\" \"read_kt ${username}.keytab\nlist\" | ktutil"
|
||||
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
base64 ${username}.keytab > ${TESTS_DIR}/integration/${username}.keytab.base64
|
||||
|
||||
docker cp $SAMBA_CONTAINER:/${username}.keytab ${TESTS_DIR}/integration/
|
||||
}
|
||||
|
||||
function main() {
|
||||
# make and start vault
|
||||
make dev
|
||||
vault server -dev -dev-root-token-id=root &
|
||||
|
||||
# start our domain controller
|
||||
SAMBA_CONTAINER=$(docker run --net=${DNS_NAME} -d -ti --privileged -e "SAMBA_DC_ADMIN_PASSWD=${DOMAIN_ADMIN_PASS}" -e "KERBEROS_PASSWORD=${DOMAIN_ADMIN_PASS}" -e SAMBA_DC_DOMAIN=${DOMAIN_NAME} -e SAMBA_DC_REALM=${REALM_NAME} "bodsch/docker-samba4:${SAMBA_VER}")
|
||||
sleep 15
|
||||
|
||||
# set up users
|
||||
add_user $DOMAIN_VAULT_ACCOUNT $DOMAIN_VAULT_PASS
|
||||
create_keytab $DOMAIN_VAULT_ACCOUNT $DOMAIN_VAULT_PASS
|
||||
|
||||
add_user $DOMAIN_USER_ACCOUNT $DOMAIN_USER_PASS
|
||||
create_keytab $DOMAIN_USER_ACCOUNT $DOMAIN_USER_PASS
|
||||
|
||||
# add the service principals we'll need
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
samba-tool spn add HTTP/localhost ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
samba-tool spn add HTTP/localhost:${VAULT_PORT} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
samba-tool spn add HTTP/localhost.${DNS_NAME} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
|
||||
docker exec $SAMBA_CONTAINER \
|
||||
samba-tool spn add HTTP/localhost.${DNS_NAME}:${VAULT_PORT} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
|
||||
|
||||
# enable and configure the kerberos plugin in Vault
|
||||
vault auth enable -passthrough-request-headers=Authorization -allowed-response-headers=www-authenticate kerberos
|
||||
vault write auth/kerberos/config keytab=@${TESTS_DIR}/integration/vault_svc.keytab.base64 service_account="vault_svc"
|
||||
vault write auth/kerberos/config/ldap binddn=${DOMAIN_VAULT_ACCOUNT}@${REALM_NAME} bindpass=${DOMAIN_VAULT_PASS} groupattr=sAMAccountName groupdn="${DOMAIN_DN}" groupfilter="(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))" insecure_tls=true starttls=true userdn="CN=Users,${DOMAIN_DN}" userattr=sAMAccountName upndomain=${REALM_NAME} url=ldaps://localhost:636
|
||||
|
||||
mkdir -p ${TESTS_DIR}/integration
|
||||
|
||||
echo "
|
||||
[libdefaults]
|
||||
default_realm = ${REALM_NAME}
|
||||
dns_lookup_realm = false
|
||||
dns_lookup_kdc = true
|
||||
ticket_lifetime = 24h
|
||||
renew_lifetime = 7d
|
||||
forwardable = true
|
||||
rdns = false
|
||||
preferred_preauth_types = 23
|
||||
[realms]
|
||||
${REALM_NAME} = {
|
||||
kdc = localhost
|
||||
admin_server = localhost
|
||||
master_kdc = localhost
|
||||
default_domain = localhost
|
||||
}
|
||||
" > ${TESTS_DIR}/integration/krb5.conf
|
||||
|
||||
echo "
|
||||
auto_auth {
|
||||
method \"kerberos\" {
|
||||
mount_path = \"auth/kerberos\"
|
||||
config = {
|
||||
username = \"$DOMAIN_USER_ACCOUNT\"
|
||||
service = \"HTTP/localhost:8200\"
|
||||
realm = \"$REALM_NAME\"
|
||||
keytab_path = \"$TESTS_DIR/integration/grace.keytab\"
|
||||
krb5conf_path = \"$TESTS_DIR/integration/krb5.conf\"
|
||||
}
|
||||
}
|
||||
sink \"file\" {
|
||||
config = {
|
||||
path = \"$TESTS_DIR/integration/agent-token.txt\"
|
||||
}
|
||||
}
|
||||
}
|
||||
" > ${TESTS_DIR}/integration/agent.conf
|
||||
|
||||
vault agent -config=${TESTS_DIR}/integration/agent.conf &
|
||||
sleep 10
|
||||
token=$(cat $TESTS_DIR/integration/agent-token.txt)
|
||||
|
||||
# clean up: kill vault and stop the docker container we started
|
||||
kill -9 $(ps aux | grep vault | awk '{print $2}' | head -1) # kill vault server
|
||||
kill -9 $(ps aux | grep vault | awk '{print $2}' | head -1) # kill vault agent
|
||||
docker rm -f ${SAMBA_CONTAINER}
|
||||
|
||||
# a valid Vault token starts with "s.", check for that
|
||||
if [[ $token != s.* ]]; then
|
||||
echo "received invalid token: $token"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "vault kerberos agent obtained auth token: $token"
|
||||
echo "exiting successfully!"
|
||||
return 0
|
||||
}
|
||||
main
|
|
@ -0,0 +1,91 @@
|
|||
package kerberos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/gokrb5/spnego"
|
||||
kerberos "github.com/hashicorp/vault-plugin-auth-kerberos"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/command/agent/auth"
|
||||
)
|
||||
|
||||
type kerberosMethod struct {
|
||||
logger hclog.Logger
|
||||
mountPath string
|
||||
loginCfg *kerberos.LoginCfg
|
||||
}
|
||||
|
||||
func NewKerberosAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("empty config")
|
||||
}
|
||||
if conf.Config == nil {
|
||||
return nil, errors.New("empty config data")
|
||||
}
|
||||
username, err := read("username", conf.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service, err := read("service", conf.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
realm, err := read("realm", conf.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keytabPath, err := read("keytab_path", conf.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
krb5ConfPath, err := read("krb5conf_path", conf.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kerberosMethod{
|
||||
logger: conf.Logger,
|
||||
mountPath: conf.MountPath,
|
||||
loginCfg: &kerberos.LoginCfg{
|
||||
Username: username,
|
||||
Service: service,
|
||||
Realm: realm,
|
||||
KeytabPath: keytabPath,
|
||||
Krb5ConfPath: krb5ConfPath,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *kerberosMethod) Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
k.logger.Trace("beginning authentication")
|
||||
authHeaderVal, err := kerberos.GetAuthHeaderVal(k.loginCfg)
|
||||
if err != nil {
|
||||
return "", nil, nil, err
|
||||
}
|
||||
var header http.Header
|
||||
header = make(map[string][]string)
|
||||
header.Set(spnego.HTTPHeaderAuthRequest, authHeaderVal)
|
||||
return k.mountPath + "/login", header, make(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
// These functions are implemented to meet the AuthHandler interface,
|
||||
// but we don't need to take advantage of them.
|
||||
func (k *kerberosMethod) NewCreds() chan struct{} { return nil }
|
||||
func (k *kerberosMethod) CredSuccess() {}
|
||||
func (k *kerberosMethod) Shutdown() {}
|
||||
|
||||
// read reads a key from a map and convert its value to a string.
|
||||
func read(key string, m map[string]interface{}) (string, error) {
|
||||
raw, ok := m[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%q is required", key)
|
||||
}
|
||||
v, ok := raw.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%q must be a string", key)
|
||||
}
|
||||
return v, nil
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package kerberos
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/command/agent/auth"
|
||||
)
|
||||
|
||||
func TestNewKerberosAuthMethod(t *testing.T) {
|
||||
if _, err := NewKerberosAuthMethod(nil); err == nil {
|
||||
t.Fatal("err should be returned for nil input")
|
||||
}
|
||||
if _, err := NewKerberosAuthMethod(&auth.AuthConfig{}); err == nil {
|
||||
t.Fatal("err should be returned for nil config map")
|
||||
}
|
||||
|
||||
authConfig := simpleAuthConfig()
|
||||
delete(authConfig.Config, "username")
|
||||
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
|
||||
t.Fatal("err should be returned for missing username")
|
||||
}
|
||||
|
||||
authConfig = simpleAuthConfig()
|
||||
delete(authConfig.Config, "service")
|
||||
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
|
||||
t.Fatal("err should be returned for missing service")
|
||||
}
|
||||
|
||||
authConfig = simpleAuthConfig()
|
||||
delete(authConfig.Config, "realm")
|
||||
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
|
||||
t.Fatal("err should be returned for missing realm")
|
||||
}
|
||||
|
||||
authConfig = simpleAuthConfig()
|
||||
delete(authConfig.Config, "keytab_path")
|
||||
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
|
||||
t.Fatal("err should be returned for missing keytab_path")
|
||||
}
|
||||
|
||||
authConfig = simpleAuthConfig()
|
||||
delete(authConfig.Config, "krb5conf_path")
|
||||
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
|
||||
t.Fatal("err should be returned for missing krb5conf_path")
|
||||
}
|
||||
|
||||
authConfig = simpleAuthConfig()
|
||||
if _, err := NewKerberosAuthMethod(authConfig); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func simpleAuthConfig() *auth.AuthConfig {
|
||||
return &auth.AuthConfig{
|
||||
Logger: hclog.NewNullLogger(),
|
||||
MountPath: "kerberos",
|
||||
WrapTTL: 20,
|
||||
Config: map[string]interface{}{
|
||||
"username": "grace",
|
||||
"service": "HTTP/05a65fad28ef.matrix.lan:8200",
|
||||
"realm": "MATRIX.LAN",
|
||||
"keytab_path": "grace.keytab",
|
||||
"krb5conf_path": "krb5.conf",
|
||||
},
|
||||
}
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
|
@ -72,15 +73,15 @@ func NewKubernetesAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
|
|||
return k, nil
|
||||
}
|
||||
|
||||
func (k *kubernetesMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) {
|
||||
func (k *kubernetesMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
|
||||
k.logger.Trace("beginning authentication")
|
||||
|
||||
jwtString, err := k.readJWT()
|
||||
if err != nil {
|
||||
return "", nil, errwrap.Wrapf("error reading JWT with Kubernetes Auth: {{err}}", err)
|
||||
return "", nil, nil, errwrap.Wrapf("error reading JWT with Kubernetes Auth: {{err}}", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/login", k.mountPath), map[string]interface{}{
|
||||
return fmt.Sprintf("%s/login", k.mountPath), nil, map[string]interface{}{
|
||||
"role": k.role,
|
||||
"jwt": jwtString,
|
||||
}, nil
|
||||
|
|
|
@ -61,7 +61,7 @@ func TestKubernetesAuth_basic(t *testing.T) {
|
|||
k.jwtData = tc.data
|
||||
}
|
||||
|
||||
_, data, err := k.Authenticate(context.Background(), nil)
|
||||
_, _, data, err := k.Authenticate(context.Background(), nil)
|
||||
if err != nil && tc.e == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package agent
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
@ -77,7 +76,6 @@ func TestAWSEndToEnd(t *testing.T) {
|
|||
// Retain thru the account number of the given arn and wildcard the rest.
|
||||
"bound_iam_principal_arn": os.Getenv(envVarAwsTestRoleArn)[:25] + "*",
|
||||
}); err != nil {
|
||||
fmt.Println(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package metricsutil
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func TestFormatFromRequest(t *testing.T) {
|
||||
|
|
|
@ -603,7 +603,7 @@ func (c *Client) ClearToken() {
|
|||
}
|
||||
|
||||
// Headers gets the current set of headers used for requests. This returns a
|
||||
// copy; to modify it make modifications locally and use SetHeaders.
|
||||
// copy; to modify it call AddHeader or SetHeaders.
|
||||
func (c *Client) Headers() http.Header {
|
||||
c.modifyLock.RLock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
|
@ -622,11 +622,19 @@ func (c *Client) Headers() http.Header {
|
|||
return ret
|
||||
}
|
||||
|
||||
// SetHeaders sets the headers to be used for future requests.
|
||||
// AddHeader allows a single header key/value pair to be added
|
||||
// in a race-safe fashion.
|
||||
func (c *Client) AddHeader(key, value string) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
c.headers.Add(key, value)
|
||||
}
|
||||
|
||||
// SetHeaders clears all previous headers and uses only the given
|
||||
// ones going forward.
|
||||
func (c *Client) SetHeaders(headers http.Header) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.headers = headers
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue