Add Kerberos auth agent (#7999)

* add kerberos auth agent

* strip old comment

* changes from feedback

* strip appengine indirect dependency
This commit is contained in:
Becca Petrin 2020-01-09 14:56:34 -08:00 committed by GitHub
parent 2c6be02579
commit c2894b8d05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 475 additions and 65 deletions

View File

@ -603,7 +603,7 @@ func (c *Client) ClearToken() {
}
// Headers gets the current set of headers used for requests. This returns a
// copy; to modify it make modifications locally and use SetHeaders.
// copy; to modify it call AddHeader or SetHeaders.
func (c *Client) Headers() http.Header {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
@ -622,11 +622,19 @@ func (c *Client) Headers() http.Header {
return ret
}
// SetHeaders sets the headers to be used for future requests.
// AddHeader allows a single header key/value pair to be added
// in a race-safe fashion.
func (c *Client) AddHeader(key, value string) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.headers.Add(key, value)
}
// SetHeaders clears all previous headers and uses only the given
// ones going forward.
func (c *Client) SetHeaders(headers http.Header) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.headers = headers
}
@ -680,8 +688,8 @@ func (c *Client) SetPolicyOverride(override bool) {
// portMap defines the standard port map
var portMap = map[string]string{
"http": "80",
"https": "443",
"http": "80",
"https": "443",
}
// NewRequest creates a new raw request object to query the Vault server
@ -703,7 +711,7 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
// Avoid lookup of SRV record if scheme is known
port, ok := portMap[addr.Scheme]
if ok {
host = net.JoinHostPort(host, port)
host = net.JoinHostPort(host, port)
} else {
// Internet Draft specifies that the SRV record is ignored if a port is given
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())

View File

@ -325,3 +325,52 @@ func TestClone(t *testing.T) {
_ = client2
}
func TestSetHeadersRaceSafe(t *testing.T) {
client, err1 := NewClient(nil)
if err1 != nil {
t.Fatalf("NewClient failed: %v", err1)
}
start := make(chan interface{})
done := make(chan interface{})
testPairs := map[string]string{
"soda": "rootbeer",
"veggie": "carrots",
"fruit": "apples",
"color": "red",
"protein": "egg",
}
for key, value := range testPairs {
tmpKey := key
tmpValue := value
go func() {
<-start
// This test fails if here, you replace client.AddHeader(tmpKey, tmpValue) with:
// headerCopy := client.Header()
// headerCopy.AddHeader(tmpKey, tmpValue)
// client.SetHeader(headerCopy)
client.AddHeader(tmpKey, tmpValue)
done <- true
}()
}
// Start everyone at once.
close(start)
// Wait until everyone is done.
for i := 0; i < len(testPairs); i++ {
<-done
}
// Check that all the test pairs are in the resulting
// headers.
resultingHeaders := client.Headers()
for key, value := range testPairs {
if resultingHeaders.Get(key) != value {
t.Fatal("expected " + value + " for " + key)
}
}
}

View File

@ -1002,7 +1002,7 @@ func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
/* ARN of an AWS role that Vault can query during testing.
This role should exist in your current AWS account and your credentials
should have iam:GetRole permissions to query it.
*/
*/
assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN")
if assumableRoleArn == "" {
t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset")

View File

@ -13,7 +13,6 @@ import (
"encoding/base64"
"encoding/pem"
"fmt"
"github.com/go-test/deep"
"math"
"math/big"
mathrand "math/rand"
@ -29,6 +28,7 @@ import (
"time"
"github.com/fatih/structs"
"github.com/go-test/deep"
"github.com/hashicorp/vault/api"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
vaulthttp "github.com/hashicorp/vault/http"

View File

@ -27,6 +27,7 @@ import (
"github.com/hashicorp/vault/command/agent/auth/cf"
"github.com/hashicorp/vault/command/agent/auth/gcp"
"github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/auth/kerberos"
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/cache"
agentConfig "github.com/hashicorp/vault/command/agent/config"
@ -385,6 +386,8 @@ func (c *AgentCommand) Run(args []string) int {
method, err = gcp.NewGCPAuthMethod(authConfig)
case "jwt":
method, err = jwt.NewJWTAuthMethod(authConfig)
case "kerberos":
method, err = kerberos.NewKerberosAuthMethod(authConfig)
case "kubernetes":
method, err = kubernetes.NewKubernetesAuthMethod(authConfig)
case "approle":

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"reflect"
"sync"
"time"
@ -174,16 +175,16 @@ type alicloudMethod struct {
stopCh chan struct{}
}
func (a *alicloudMethod) Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error) {
func (a *alicloudMethod) Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) {
a.credLock.Lock()
defer a.credLock.Unlock()
a.logger.Trace("beginning authentication")
data, err := tools.GenerateLoginData(a.role, a.lastCreds, a.region)
if err != nil {
return "", nil, err
return "", nil, nil, err
}
return fmt.Sprintf("%s/login", a.mountPath), data, nil
return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil
}
func (a *alicloudMethod) NewCreds() chan struct{} {

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
@ -87,18 +88,18 @@ func NewApproleAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return a, nil
}
func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) {
func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
if _, err := os.Stat(a.roleIDFilePath); err == nil {
roleID, err := ioutil.ReadFile(a.roleIDFilePath)
if err != nil {
if a.cachedRoleID == "" {
return "", nil, errwrap.Wrapf("error reading role ID file and no cached role ID known: {{err}}", err)
return "", nil, nil, errwrap.Wrapf("error reading role ID file and no cached role ID known: {{err}}", err)
}
a.logger.Warn("error reading role ID file", "error", err)
}
if len(roleID) == 0 {
if a.cachedRoleID == "" {
return "", nil, errors.New("role ID file empty and no cached role ID known")
return "", nil, nil, errors.New("role ID file empty and no cached role ID known")
}
a.logger.Warn("role ID file exists but read empty value, re-using cached value")
} else {
@ -107,11 +108,11 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
}
if a.cachedRoleID == "" {
return "", nil, errors.New("no known role ID")
return "", nil, nil, errors.New("no known role ID")
}
if a.secretIDFilePath == "" {
return fmt.Sprintf("%s/login", a.mountPath), map[string]interface{}{
return fmt.Sprintf("%s/login", a.mountPath), nil, map[string]interface{}{
"role_id": a.cachedRoleID,
}, nil
}
@ -120,13 +121,13 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
secretID, err := ioutil.ReadFile(a.secretIDFilePath)
if err != nil {
if a.cachedSecretID == "" {
return "", nil, errwrap.Wrapf("error reading secret ID file and no cached secret ID known: {{err}}", err)
return "", nil, nil, errwrap.Wrapf("error reading secret ID file and no cached secret ID known: {{err}}", err)
}
a.logger.Warn("error reading secret ID file", "error", err)
}
if len(secretID) == 0 {
if a.cachedSecretID == "" {
return "", nil, errors.New("secret ID file empty and no cached secret ID known")
return "", nil, nil, errors.New("secret ID file empty and no cached secret ID known")
}
a.logger.Warn("secret ID file exists but read empty value, re-using cached value")
} else {
@ -134,50 +135,50 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
if a.secretIDResponseWrappingPath != "" {
clonedClient, err := client.Clone()
if err != nil {
return "", nil, errwrap.Wrapf("error cloning client to unwrap secret ID: {{err}}", err)
return "", nil, nil, errwrap.Wrapf("error cloning client to unwrap secret ID: {{err}}", err)
}
clonedClient.SetToken(stringSecretID)
// Validate the creation path
resp, err := clonedClient.Logical().Read("sys/wrapping/lookup")
if err != nil {
return "", nil, errwrap.Wrapf("error looking up wrapped secret ID: {{err}}", err)
return "", nil, nil, errwrap.Wrapf("error looking up wrapped secret ID: {{err}}", err)
}
if resp == nil {
return "", nil, errors.New("response nil when looking up wrapped secret ID")
return "", nil, nil, errors.New("response nil when looking up wrapped secret ID")
}
if resp.Data == nil {
return "", nil, errors.New("data in response nil when looking up wrapped secret ID")
return "", nil, nil, errors.New("data in response nil when looking up wrapped secret ID")
}
creationPathRaw, ok := resp.Data["creation_path"]
if !ok {
return "", nil, errors.New("creation_path in response nil when looking up wrapped secret ID")
return "", nil, nil, errors.New("creation_path in response nil when looking up wrapped secret ID")
}
creationPath, ok := creationPathRaw.(string)
if !ok {
return "", nil, errors.New("creation_path in response could not be parsed as string when looking up wrapped secret ID")
return "", nil, nil, errors.New("creation_path in response could not be parsed as string when looking up wrapped secret ID")
}
if creationPath != a.secretIDResponseWrappingPath {
a.logger.Error("SECURITY: unable to validate wrapping token creation path", "expected", a.secretIDResponseWrappingPath, "found", creationPath)
return "", nil, errors.New("unable to validate wrapping token creation path")
return "", nil, nil, errors.New("unable to validate wrapping token creation path")
}
// Now get the secret ID
resp, err = clonedClient.Logical().Unwrap("")
if err != nil {
return "", nil, errwrap.Wrapf("error unwrapping secret ID: {{err}}", err)
return "", nil, nil, errwrap.Wrapf("error unwrapping secret ID: {{err}}", err)
}
if resp == nil {
return "", nil, errors.New("response nil when unwrapping secret ID")
return "", nil, nil, errors.New("response nil when unwrapping secret ID")
}
if resp.Data == nil {
return "", nil, errors.New("data in response nil when unwrapping secret ID")
return "", nil, nil, errors.New("data in response nil when unwrapping secret ID")
}
secretIDRaw, ok := resp.Data["secret_id"]
if !ok {
return "", nil, errors.New("secret_id in response nil when unwrapping secret ID")
return "", nil, nil, errors.New("secret_id in response nil when unwrapping secret ID")
}
secretID, ok := secretIDRaw.(string)
if !ok {
return "", nil, errors.New("secret_id in response could not be parsed as string when unwrapping secret ID")
return "", nil, nil, errors.New("secret_id in response could not be parsed as string when unwrapping secret ID")
}
stringSecretID = secretID
}
@ -191,10 +192,10 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s
}
if a.cachedSecretID == "" {
return "", nil, errors.New("no known secret ID")
return "", nil, nil, errors.New("no known secret ID")
}
return fmt.Sprintf("%s/login", a.mountPath), map[string]interface{}{
return fmt.Sprintf("%s/login", a.mountPath), nil, map[string]interface{}{
"role_id": a.cachedRoleID,
"secret_id": a.cachedSecretID,
}, nil

View File

@ -3,6 +3,7 @@ package auth
import (
"context"
"math/rand"
"net/http"
"time"
hclog "github.com/hashicorp/go-hclog"
@ -11,7 +12,9 @@ import (
)
type AuthMethod interface {
Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error)
// Authenticate returns a mount path, header, request body, and error.
// The header may be nil if no special header is needed.
Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error)
NewCreds() chan struct{}
CredSuccess()
Shutdown()
@ -119,7 +122,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))
ah.logger.Info("authenticating")
path, data, err := am.Authenticate(ctx, ah.client)
path, header, data, err := am.Authenticate(ctx, ah.client)
if err != nil {
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
@ -139,6 +142,11 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
})
clientToUse = wrapClient
}
for key, values := range header {
for _, value := range values {
clientToUse.AddHeader(key, value)
}
}
secret, err := clientToUse.Logical().Write(path, data)
// Check errors/sanity

View File

@ -2,6 +2,7 @@ package auth
import (
"context"
"net/http"
"testing"
"time"
@ -31,14 +32,14 @@ func newUserpassTestMethod(t *testing.T, client *api.Client) AuthMethod {
return &userpassTestMethod{}
}
func (u *userpassTestMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
func (u *userpassTestMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
_, err := client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
"password": "bar",
})
if err != nil {
return "", nil, err
return "", nil, nil, err
}
return "auth/userpass/login/foo", map[string]interface{}{
return "auth/userpass/login/foo", nil, map[string]interface{}{
"password": "bar",
}, nil
}

View File

@ -179,7 +179,7 @@ func NewAWSAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return a, nil
}
func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retToken string, retData map[string]interface{}, retErr error) {
func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retToken string, header http.Header, retData map[string]interface{}, retErr error) {
a.logger.Trace("beginning authentication")
data := make(map[string]interface{})
@ -266,7 +266,7 @@ func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retTo
data["role"] = a.role
return fmt.Sprintf("%s/login", a.mountPath), data, nil
return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil
}
func (a *awsMethod) NewCreds() chan struct{} {

View File

@ -74,7 +74,7 @@ func NewAzureAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return a, nil
}
func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) {
func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, header http.Header, retData map[string]interface{}, retErr error) {
a.logger.Trace("beginning authentication")
// Fetch instance data
@ -126,7 +126,7 @@ func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (ret
"jwt": identity.AccessToken,
}
return fmt.Sprintf("%s/login", a.mountPath), data, nil
return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil
}
func (a *azureMethod) NewCreds() chan struct{} {

View File

@ -4,8 +4,9 @@ import (
"context"
"errors"
"fmt"
"net/http"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/auth"
)
@ -44,7 +45,7 @@ func NewCertAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return c, nil
}
func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
c.logger.Trace("beginning authentication")
authMap := map[string]interface{}{}
@ -53,7 +54,7 @@ func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string
authMap["name"] = c.name
}
return fmt.Sprintf("%s/login", c.mountPath), authMap, nil
return fmt.Sprintf("%s/login", c.mountPath), nil, authMap, nil
}
func (c *certMethod) NewCreds() chan struct{} {

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"time"
@ -41,18 +42,18 @@ func NewCFAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return a, nil
}
func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) {
func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
pathToClientCert := os.Getenv(cf.EnvVarInstanceCertificate)
if pathToClientCert == "" {
return "", nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceCertificate)
return "", nil, nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceCertificate)
}
certBytes, err := ioutil.ReadFile(pathToClientCert)
if err != nil {
return "", nil, err
return "", nil, nil, err
}
pathToClientKey := os.Getenv(cf.EnvVarInstanceKey)
if pathToClientKey == "" {
return "", nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceKey)
return "", nil, nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceKey)
}
signingTime := time.Now().UTC()
signatureData := &signatures.SignatureData{
@ -62,7 +63,7 @@ func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string
}
signature, err := signatures.Sign(pathToClientKey, signatureData)
if err != nil {
return "", nil, err
return "", nil, nil, err
}
data := map[string]interface{}{
"role": p.roleName,
@ -70,7 +71,7 @@ func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string
"signing_time": signingTime.Format(signatures.TimeFormat),
"signature": signature,
}
return fmt.Sprintf("%s/login", p.mountPath), data, nil
return fmt.Sprintf("%s/login", p.mountPath), nil, data, nil
}
func (p *cfMethod) NewCreds() chan struct{} {

View File

@ -116,7 +116,7 @@ func NewGCPAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return g, nil
}
func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) {
func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, header http.Header, retData map[string]interface{}, retErr error) {
g.logger.Trace("beginning authentication")
data := make(map[string]interface{})
@ -227,7 +227,7 @@ func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPa
data["role"] = g.role
data["jwt"] = jwt
return fmt.Sprintf("%s/login", g.mountPath), data, nil
return fmt.Sprintf("%s/login", g.mountPath), nil, data, nil
}
func (g *gcpMethod) NewCreds() chan struct{} {

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"sync"
"sync/atomic"
@ -85,17 +86,17 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return j, nil
}
func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
j.logger.Trace("beginning authentication")
j.ingressToken()
latestToken := j.latestToken.Load().(string)
if latestToken == "" {
return "", nil, errors.New("latest known jwt is empty, cannot authenticate")
return "", nil, nil, errors.New("latest known jwt is empty, cannot authenticate")
}
return fmt.Sprintf("%s/login", j.mountPath), map[string]interface{}{
return fmt.Sprintf("%s/login", j.mountPath), nil, map[string]interface{}{
"role": j.role,
"jwt": latestToken,
}, nil

View File

@ -0,0 +1,170 @@
#!/bin/bash
# Instructions
# This integration test is for the Vault Kerberos agent.
# Before running, execute:
# pip install --quiet requests-kerberos
# Then run this test from Vault's home directory.
# ./command/agent/auth/kerberos/integtest/integrationtest.sh
if [[ "$OSTYPE" == "darwin"* ]]; then
base64cmd="base64 -D"
else
base64cmd="base64 -d"
fi
VAULT_PORT=8200
SAMBA_VER=4.8.12
export VAULT_TOKEN=${VAULT_TOKEN:-myroot}
DOMAIN_ADMIN_PASS=Pa55word!
DOMAIN_VAULT_ACCOUNT=vault_svc
DOMAIN_VAULT_PASS=vaultPa55word!
DOMAIN_USER_ACCOUNT=grace
DOMAIN_USER_PASS=gracePa55word!
SAMBA_CONF_FILE=/srv/etc/smb.conf
DOMAIN_NAME=matrix
DNS_NAME=host
REALM_NAME=MATRIX.LAN
DOMAIN_DN=DC=MATRIX,DC=LAN
TESTS_DIR=/tmp/vault_plugin_tests
function add_user() {
username="${1}"
password="${2}"
if [[ $(check_user ${username}) -eq 0 ]]
then
echo "add user '${username}'"
docker exec $SAMBA_CONTAINER \
/usr/bin/samba-tool user create \
${username} \
${password}\
--configfile=${SAMBA_CONF_FILE}
fi
}
function check_user() {
username="${1}"
docker exec $SAMBA_CONTAINER \
/usr/bin/samba-tool user list \
--configfile=${SAMBA_CONF_FILE} \
| grep -c ${username}
}
function create_keytab() {
username="${1}"
password="${2}"
user_kvno=$(docker exec $SAMBA_CONTAINER \
bash -c "ldapsearch -H ldaps://localhost -D \"Administrator@${REALM_NAME}\" -w \"${DOMAIN_ADMIN_PASS}\" -b \"CN=Users,${DOMAIN_DN}\" -LLL \"(&(objectClass=user)(sAMAccountName=${username}))\" msDS-KeyVersionNumber | sed -n 's/^[ \t]*msDS-KeyVersionNumber:[ \t]*\(.*\)/\1/p'")
docker exec $SAMBA_CONTAINER \
bash -c "printf \"%b\" \"addent -password -p \"${username}@${REALM_NAME}\" -k ${user_kvno} -e rc4-hmac\n${password}\nwrite_kt ${username}.keytab\" | ktutil"
docker exec $SAMBA_CONTAINER \
bash -c "printf \"%b\" \"read_kt ${username}.keytab\nlist\" | ktutil"
docker exec $SAMBA_CONTAINER \
base64 ${username}.keytab > ${TESTS_DIR}/integration/${username}.keytab.base64
docker cp $SAMBA_CONTAINER:/${username}.keytab ${TESTS_DIR}/integration/
}
function main() {
# make and start vault
make dev
vault server -dev -dev-root-token-id=root &
# start our domain controller
SAMBA_CONTAINER=$(docker run --net=${DNS_NAME} -d -ti --privileged -e "SAMBA_DC_ADMIN_PASSWD=${DOMAIN_ADMIN_PASS}" -e "KERBEROS_PASSWORD=${DOMAIN_ADMIN_PASS}" -e SAMBA_DC_DOMAIN=${DOMAIN_NAME} -e SAMBA_DC_REALM=${REALM_NAME} "bodsch/docker-samba4:${SAMBA_VER}")
sleep 15
# set up users
add_user $DOMAIN_VAULT_ACCOUNT $DOMAIN_VAULT_PASS
create_keytab $DOMAIN_VAULT_ACCOUNT $DOMAIN_VAULT_PASS
add_user $DOMAIN_USER_ACCOUNT $DOMAIN_USER_PASS
create_keytab $DOMAIN_USER_ACCOUNT $DOMAIN_USER_PASS
# add the service principals we'll need
docker exec $SAMBA_CONTAINER \
samba-tool spn add HTTP/localhost ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
docker exec $SAMBA_CONTAINER \
samba-tool spn add HTTP/localhost:${VAULT_PORT} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
docker exec $SAMBA_CONTAINER \
samba-tool spn add HTTP/localhost.${DNS_NAME} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
docker exec $SAMBA_CONTAINER \
samba-tool spn add HTTP/localhost.${DNS_NAME}:${VAULT_PORT} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE}
# enable and configure the kerberos plugin in Vault
vault auth enable -passthrough-request-headers=Authorization -allowed-response-headers=www-authenticate kerberos
vault write auth/kerberos/config keytab=@${TESTS_DIR}/integration/vault_svc.keytab.base64 service_account="vault_svc"
vault write auth/kerberos/config/ldap binddn=${DOMAIN_VAULT_ACCOUNT}@${REALM_NAME} bindpass=${DOMAIN_VAULT_PASS} groupattr=sAMAccountName groupdn="${DOMAIN_DN}" groupfilter="(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))" insecure_tls=true starttls=true userdn="CN=Users,${DOMAIN_DN}" userattr=sAMAccountName upndomain=${REALM_NAME} url=ldaps://localhost:636
mkdir -p ${TESTS_DIR}/integration
echo "
[libdefaults]
default_realm = ${REALM_NAME}
dns_lookup_realm = false
dns_lookup_kdc = true
ticket_lifetime = 24h
renew_lifetime = 7d
forwardable = true
rdns = false
preferred_preauth_types = 23
[realms]
${REALM_NAME} = {
kdc = localhost
admin_server = localhost
master_kdc = localhost
default_domain = localhost
}
" > ${TESTS_DIR}/integration/krb5.conf
echo "
auto_auth {
method \"kerberos\" {
mount_path = \"auth/kerberos\"
config = {
username = \"$DOMAIN_USER_ACCOUNT\"
service = \"HTTP/localhost:8200\"
realm = \"$REALM_NAME\"
keytab_path = \"$TESTS_DIR/integration/grace.keytab\"
krb5conf_path = \"$TESTS_DIR/integration/krb5.conf\"
}
}
sink \"file\" {
config = {
path = \"$TESTS_DIR/integration/agent-token.txt\"
}
}
}
" > ${TESTS_DIR}/integration/agent.conf
vault agent -config=${TESTS_DIR}/integration/agent.conf &
sleep 10
token=$(cat $TESTS_DIR/integration/agent-token.txt)
# clean up: kill vault and stop the docker container we started
kill -9 $(ps aux | grep vault | awk '{print $2}' | head -1) # kill vault server
kill -9 $(ps aux | grep vault | awk '{print $2}' | head -1) # kill vault agent
docker rm -f ${SAMBA_CONTAINER}
# a valid Vault token starts with "s.", check for that
if [[ $token != s.* ]]; then
echo "received invalid token: $token"
return 1
fi
echo "vault kerberos agent obtained auth token: $token"
echo "exiting successfully!"
return 0
}
main

View File

@ -0,0 +1,91 @@
package kerberos
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/gokrb5/spnego"
kerberos "github.com/hashicorp/vault-plugin-auth-kerberos"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/auth"
)
type kerberosMethod struct {
logger hclog.Logger
mountPath string
loginCfg *kerberos.LoginCfg
}
func NewKerberosAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil {
return nil, errors.New("empty config")
}
if conf.Config == nil {
return nil, errors.New("empty config data")
}
username, err := read("username", conf.Config)
if err != nil {
return nil, err
}
service, err := read("service", conf.Config)
if err != nil {
return nil, err
}
realm, err := read("realm", conf.Config)
if err != nil {
return nil, err
}
keytabPath, err := read("keytab_path", conf.Config)
if err != nil {
return nil, err
}
krb5ConfPath, err := read("krb5conf_path", conf.Config)
if err != nil {
return nil, err
}
return &kerberosMethod{
logger: conf.Logger,
mountPath: conf.MountPath,
loginCfg: &kerberos.LoginCfg{
Username: username,
Service: service,
Realm: realm,
KeytabPath: keytabPath,
Krb5ConfPath: krb5ConfPath,
},
}, nil
}
func (k *kerberosMethod) Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) {
k.logger.Trace("beginning authentication")
authHeaderVal, err := kerberos.GetAuthHeaderVal(k.loginCfg)
if err != nil {
return "", nil, nil, err
}
var header http.Header
header = make(map[string][]string)
header.Set(spnego.HTTPHeaderAuthRequest, authHeaderVal)
return k.mountPath + "/login", header, make(map[string]interface{}), nil
}
// These functions are implemented to meet the AuthHandler interface,
// but we don't need to take advantage of them.
func (k *kerberosMethod) NewCreds() chan struct{} { return nil }
func (k *kerberosMethod) CredSuccess() {}
func (k *kerberosMethod) Shutdown() {}
// read reads a key from a map and convert its value to a string.
func read(key string, m map[string]interface{}) (string, error) {
raw, ok := m[key]
if !ok {
return "", fmt.Errorf("%q is required", key)
}
v, ok := raw.(string)
if !ok {
return "", fmt.Errorf("%q must be a string", key)
}
return v, nil
}

View File

@ -0,0 +1,67 @@
package kerberos
import (
"testing"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/command/agent/auth"
)
func TestNewKerberosAuthMethod(t *testing.T) {
if _, err := NewKerberosAuthMethod(nil); err == nil {
t.Fatal("err should be returned for nil input")
}
if _, err := NewKerberosAuthMethod(&auth.AuthConfig{}); err == nil {
t.Fatal("err should be returned for nil config map")
}
authConfig := simpleAuthConfig()
delete(authConfig.Config, "username")
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
t.Fatal("err should be returned for missing username")
}
authConfig = simpleAuthConfig()
delete(authConfig.Config, "service")
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
t.Fatal("err should be returned for missing service")
}
authConfig = simpleAuthConfig()
delete(authConfig.Config, "realm")
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
t.Fatal("err should be returned for missing realm")
}
authConfig = simpleAuthConfig()
delete(authConfig.Config, "keytab_path")
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
t.Fatal("err should be returned for missing keytab_path")
}
authConfig = simpleAuthConfig()
delete(authConfig.Config, "krb5conf_path")
if _, err := NewKerberosAuthMethod(authConfig); err == nil {
t.Fatal("err should be returned for missing krb5conf_path")
}
authConfig = simpleAuthConfig()
if _, err := NewKerberosAuthMethod(authConfig); err != nil {
t.Fatal(err)
}
}
func simpleAuthConfig() *auth.AuthConfig {
return &auth.AuthConfig{
Logger: hclog.NewNullLogger(),
MountPath: "kerberos",
WrapTTL: 20,
Config: map[string]interface{}{
"username": "grace",
"service": "HTTP/05a65fad28ef.matrix.lan:8200",
"realm": "MATRIX.LAN",
"keytab_path": "grace.keytab",
"krb5conf_path": "krb5.conf",
},
}
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
@ -72,15 +73,15 @@ func NewKubernetesAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return k, nil
}
func (k *kubernetesMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) {
func (k *kubernetesMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
k.logger.Trace("beginning authentication")
jwtString, err := k.readJWT()
if err != nil {
return "", nil, errwrap.Wrapf("error reading JWT with Kubernetes Auth: {{err}}", err)
return "", nil, nil, errwrap.Wrapf("error reading JWT with Kubernetes Auth: {{err}}", err)
}
return fmt.Sprintf("%s/login", k.mountPath), map[string]interface{}{
return fmt.Sprintf("%s/login", k.mountPath), nil, map[string]interface{}{
"role": k.role,
"jwt": jwtString,
}, nil

View File

@ -61,7 +61,7 @@ func TestKubernetesAuth_basic(t *testing.T) {
k.jwtData = tc.data
}
_, data, err := k.Authenticate(context.Background(), nil)
_, _, data, err := k.Authenticate(context.Background(), nil)
if err != nil && tc.e == nil {
t.Fatal(err)
}

View File

@ -2,7 +2,6 @@ package agent
import (
"context"
"fmt"
"io/ioutil"
"os"
"testing"
@ -77,7 +76,6 @@ func TestAWSEndToEnd(t *testing.T) {
// Retain thru the account number of the given arn and wildcard the rest.
"bound_iam_principal_arn": os.Getenv(envVarAwsTestRoleArn)[:25] + "*",
}); err != nil {
fmt.Println(err)
t.Fatal(err)
}

View File

@ -1,8 +1,9 @@
package metricsutil
import (
"github.com/hashicorp/vault/sdk/logical"
"testing"
"github.com/hashicorp/vault/sdk/logical"
)
func TestFormatFromRequest(t *testing.T) {

View File

@ -603,7 +603,7 @@ func (c *Client) ClearToken() {
}
// Headers gets the current set of headers used for requests. This returns a
// copy; to modify it make modifications locally and use SetHeaders.
// copy; to modify it call AddHeader or SetHeaders.
func (c *Client) Headers() http.Header {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
@ -622,11 +622,19 @@ func (c *Client) Headers() http.Header {
return ret
}
// SetHeaders sets the headers to be used for future requests.
// AddHeader allows a single header key/value pair to be added
// in a race-safe fashion.
func (c *Client) AddHeader(key, value string) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.headers.Add(key, value)
}
// SetHeaders clears all previous headers and uses only the given
// ones going forward.
func (c *Client) SetHeaders(headers http.Header) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.headers = headers
}