d710f8e8dc
The SSH secrets engine previously split the `validPrincipals` field on comma, then if user templating is enabled, evaluated the templates on each substring. This meant the identity template was only ever allowed to return a single principal. There are use cases where it would be helpful for identity metadata to contain a list of valid principals and for the identity template to be able to inject all of those as valid principals. This change inverts the order of processing. First the template is evaluated, and then the resulting string is split on commas. This allows the identity template to return a single comma-separated string with multiple permitted principals. There is a potential security implication here, that if a user is allowed to update their own identity metadata, they may be able to elevate privileges where previously this was not possible. Fixes #11038
559 lines
17 KiB
Go
559 lines
17 KiB
Go
package ssh
|
|
|
|
import (
|
|
"context"
|
|
"crypto/dsa"
|
|
"crypto/ecdsa"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
|
"github.com/hashicorp/vault/sdk/framework"
|
|
"github.com/hashicorp/vault/sdk/helper/certutil"
|
|
"github.com/hashicorp/vault/sdk/helper/strutil"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
var containsTemplateRegex = regexp.MustCompile(`{{.+?}}`)
|
|
|
|
var ecCurveBitsToAlgoName = map[int]string{
|
|
256: ssh.KeyAlgoECDSA256,
|
|
384: ssh.KeyAlgoECDSA384,
|
|
521: ssh.KeyAlgoECDSA521,
|
|
}
|
|
|
|
// Note on ecCurveBitsToAlgoName lookups: if the algorithm is not found, it
// could be that we have a curve that we haven't added a constant for yet.
// Such a key could still be allowed (assuming x/crypto/ssh can parse it) by
// setting an "ec: <keyBits>" mapping rather than using a named SSH key type,
// so returning an error on a failed lookup isn't advisable.
|
|
|
|
type creationBundle struct {
|
|
KeyID string
|
|
ValidPrincipals []string
|
|
PublicKey ssh.PublicKey
|
|
CertificateType uint32
|
|
TTL time.Duration
|
|
Signer ssh.Signer
|
|
Role *sshRole
|
|
CriticalOptions map[string]string
|
|
Extensions map[string]string
|
|
}
|
|
|
|
func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, publicKey ssh.PublicKey) (*logical.Response, error) {
|
|
// Note that these various functions always return "user errors" so we pass
|
|
// them as 4xx values
|
|
keyID, err := b.calculateKeyID(data, req, role, publicKey)
|
|
if err != nil {
|
|
return logical.ErrorResponse(err.Error()), nil
|
|
}
|
|
|
|
certificateType, err := b.calculateCertificateType(data, role)
|
|
if err != nil {
|
|
return logical.ErrorResponse(err.Error()), nil
|
|
}
|
|
|
|
var parsedPrincipals []string
|
|
if certificateType == ssh.HostCert {
|
|
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, role.AllowedDomainsTemplate, validateValidPrincipalForHosts(role))
|
|
if err != nil {
|
|
return logical.ErrorResponse(err.Error()), nil
|
|
}
|
|
} else {
|
|
defaultPrincipal := role.DefaultUser
|
|
if role.DefaultUserTemplate {
|
|
defaultPrincipal, err = b.renderPrincipal(role.DefaultUser, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, defaultPrincipal, role.AllowedUsers, role.AllowedUsersTemplate, strutil.StrListContains)
|
|
if err != nil {
|
|
return logical.ErrorResponse(err.Error()), nil
|
|
}
|
|
}
|
|
|
|
ttl, err := b.calculateTTL(data, role)
|
|
if err != nil {
|
|
return logical.ErrorResponse(err.Error()), nil
|
|
}
|
|
|
|
criticalOptions, err := b.calculateCriticalOptions(data, role)
|
|
if err != nil {
|
|
return logical.ErrorResponse(err.Error()), nil
|
|
}
|
|
|
|
extensions, addExtTemplatingWarning, err := b.calculateExtensions(data, req, role)
|
|
if err != nil {
|
|
return logical.ErrorResponse(err.Error()), nil
|
|
}
|
|
|
|
privateKeyEntry, err := caKey(ctx, req.Storage, caPrivateKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read CA private key: %w", err)
|
|
}
|
|
if privateKeyEntry == nil || privateKeyEntry.Key == "" {
|
|
return nil, errors.New("failed to read CA private key")
|
|
}
|
|
|
|
signer, err := ssh.ParsePrivateKey([]byte(privateKeyEntry.Key))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse stored CA private key: %w", err)
|
|
}
|
|
|
|
cBundle := creationBundle{
|
|
KeyID: keyID,
|
|
PublicKey: publicKey,
|
|
Signer: signer,
|
|
ValidPrincipals: parsedPrincipals,
|
|
TTL: ttl,
|
|
CertificateType: certificateType,
|
|
Role: role,
|
|
CriticalOptions: criticalOptions,
|
|
Extensions: extensions,
|
|
}
|
|
|
|
certificate, err := cBundle.sign()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
signedSSHCertificate := ssh.MarshalAuthorizedKey(certificate)
|
|
if len(signedSSHCertificate) == 0 {
|
|
return nil, errors.New("error marshaling signed certificate")
|
|
}
|
|
|
|
response := &logical.Response{
|
|
Data: map[string]interface{}{
|
|
"serial_number": strconv.FormatUint(certificate.Serial, 16),
|
|
"signed_key": string(signedSSHCertificate),
|
|
},
|
|
}
|
|
|
|
if addExtTemplatingWarning {
|
|
response.AddWarning("default_extension templating enabled with at least one extension requiring identity templating. However, this request lacked identity entity information, causing one or more extensions to be skipped from the generated certificate.")
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func (b *backend) renderPrincipal(principal string, req *logical.Request) (string, error) {
|
|
// Look for templating markers {{ .* }}
|
|
matched := containsTemplateRegex.MatchString(principal)
|
|
if matched {
|
|
if req.EntityID != "" {
|
|
// Retrieve principal based on template + entityID from request.
|
|
renderedPrincipal, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System())
|
|
if err != nil {
|
|
return "", fmt.Errorf("template '%s' could not be rendered -> %s", principal, err)
|
|
}
|
|
return renderedPrincipal, nil
|
|
}
|
|
}
|
|
// Static principal
|
|
return principal, nil
|
|
}
|
|
|
|
func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, enableTemplating bool, validatePrincipal func([]string, string) bool) ([]string, error) {
|
|
validPrincipals := ""
|
|
validPrincipalsRaw, ok := data.GetOk("valid_principals")
|
|
if ok {
|
|
validPrincipals = validPrincipalsRaw.(string)
|
|
} else {
|
|
validPrincipals = defaultPrincipal
|
|
}
|
|
|
|
parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false)
|
|
// Build list of allowed Principals from template and static principalsAllowedByRole
|
|
var allowedPrincipals []string
|
|
if enableTemplating {
|
|
rendered, err := b.renderPrincipal(principalsAllowedByRole, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allowedPrincipals = strutil.RemoveDuplicates(strutil.ParseStringSlice(rendered, ","), false)
|
|
} else {
|
|
allowedPrincipals = strutil.RemoveDuplicates(strutil.ParseStringSlice(principalsAllowedByRole, ","), false)
|
|
}
|
|
|
|
switch {
|
|
case len(parsedPrincipals) == 0:
|
|
// There is nothing to process
|
|
return nil, nil
|
|
case len(allowedPrincipals) == 0:
|
|
// User has requested principals to be set, but role is not configured
|
|
// with any principals
|
|
return nil, fmt.Errorf("role is not configured to allow any principals")
|
|
default:
|
|
// Role was explicitly configured to allow any principal.
|
|
if principalsAllowedByRole == "*" {
|
|
return parsedPrincipals, nil
|
|
}
|
|
|
|
for _, principal := range parsedPrincipals {
|
|
if !validatePrincipal(strutil.RemoveDuplicates(allowedPrincipals, false), principal) {
|
|
return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal)
|
|
}
|
|
}
|
|
return parsedPrincipals, nil
|
|
}
|
|
}
|
|
|
|
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
|
|
return func(allowedPrincipals []string, validPrincipal string) bool {
|
|
for _, allowedPrincipal := range allowedPrincipals {
|
|
if allowedPrincipal == validPrincipal && role.AllowBareDomains {
|
|
return true
|
|
}
|
|
if role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+allowedPrincipal) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
|
|
requestedCertificateType := data.Get("cert_type").(string)
|
|
|
|
var certificateType uint32
|
|
switch requestedCertificateType {
|
|
case "user":
|
|
if !role.AllowUserCertificates {
|
|
return 0, errors.New("cert_type 'user' is not allowed by role")
|
|
}
|
|
certificateType = ssh.UserCert
|
|
case "host":
|
|
if !role.AllowHostCertificates {
|
|
return 0, errors.New("cert_type 'host' is not allowed by role")
|
|
}
|
|
certificateType = ssh.HostCert
|
|
default:
|
|
return 0, errors.New("cert_type must be either 'user' or 'host'")
|
|
}
|
|
|
|
return certificateType, nil
|
|
}
|
|
|
|
func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) {
|
|
reqID := data.Get("key_id").(string)
|
|
|
|
if reqID != "" {
|
|
if !role.AllowUserKeyIDs {
|
|
return "", fmt.Errorf("setting key_id is not allowed by role")
|
|
}
|
|
return reqID, nil
|
|
}
|
|
|
|
keyIDFormat := "vault-{{token_display_name}}-{{public_key_hash}}"
|
|
if req.DisplayName == "" {
|
|
keyIDFormat = "vault-{{public_key_hash}}"
|
|
}
|
|
|
|
if role.KeyIDFormat != "" {
|
|
keyIDFormat = role.KeyIDFormat
|
|
}
|
|
|
|
keyID := substQuery(keyIDFormat, map[string]string{
|
|
"token_display_name": req.DisplayName,
|
|
"role_name": data.Get("role").(string),
|
|
"public_key_hash": fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())),
|
|
})
|
|
|
|
return keyID, nil
|
|
}
|
|
|
|
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
|
|
unparsedCriticalOptions := data.Get("critical_options").(map[string]interface{})
|
|
if len(unparsedCriticalOptions) == 0 {
|
|
return role.DefaultCriticalOptions, nil
|
|
}
|
|
|
|
criticalOptions := convertMapToStringValue(unparsedCriticalOptions)
|
|
|
|
if role.AllowedCriticalOptions != "" {
|
|
notAllowedOptions := []string{}
|
|
allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",")
|
|
|
|
for option := range criticalOptions {
|
|
if !strutil.StrListContains(allowedCriticalOptions, option) {
|
|
notAllowedOptions = append(notAllowedOptions, option)
|
|
}
|
|
}
|
|
|
|
if len(notAllowedOptions) != 0 {
|
|
return nil, fmt.Errorf("critical options not on allowed list: %v", notAllowedOptions)
|
|
}
|
|
}
|
|
|
|
return criticalOptions, nil
|
|
}
|
|
|
|
func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, bool, error) {
|
|
unparsedExtensions := data.Get("extensions").(map[string]interface{})
|
|
extensions := make(map[string]string)
|
|
|
|
if len(unparsedExtensions) > 0 {
|
|
extensions := convertMapToStringValue(unparsedExtensions)
|
|
if role.AllowedExtensions == "*" {
|
|
// Allowed extensions was configured to allow all
|
|
return extensions, false, nil
|
|
}
|
|
|
|
notAllowed := []string{}
|
|
allowedExtensions := strings.Split(role.AllowedExtensions, ",")
|
|
for extensionKey := range extensions {
|
|
if !strutil.StrListContains(allowedExtensions, extensionKey) {
|
|
notAllowed = append(notAllowed, extensionKey)
|
|
}
|
|
}
|
|
|
|
if len(notAllowed) != 0 {
|
|
return nil, false, fmt.Errorf("extensions %v are not on allowed list", notAllowed)
|
|
}
|
|
return extensions, false, nil
|
|
}
|
|
|
|
haveMissingEntityInfoWithTemplatedExt := false
|
|
|
|
if role.DefaultExtensionsTemplate {
|
|
for extensionKey, extensionValue := range role.DefaultExtensions {
|
|
// Look for templating markers {{ .* }}
|
|
matched := containsTemplateRegex.MatchString(extensionValue)
|
|
if matched {
|
|
if req.EntityID != "" {
|
|
// Retrieve extension value based on template + entityID from request.
|
|
templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System())
|
|
if err == nil {
|
|
// Template returned an extension value that we can use
|
|
extensions[extensionKey] = templateExtensionValue
|
|
} else {
|
|
return nil, false, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err)
|
|
}
|
|
} else {
|
|
haveMissingEntityInfoWithTemplatedExt = true
|
|
}
|
|
} else {
|
|
// Static extension value or err template
|
|
extensions[extensionKey] = extensionValue
|
|
}
|
|
}
|
|
} else {
|
|
extensions = role.DefaultExtensions
|
|
}
|
|
|
|
return extensions, haveMissingEntityInfoWithTemplatedExt, nil
|
|
}
|
|
|
|
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
|
|
var ttl, maxTTL time.Duration
|
|
var err error
|
|
|
|
ttlRaw, specifiedTTL := data.GetOk("ttl")
|
|
if specifiedTTL {
|
|
ttl = time.Duration(ttlRaw.(int)) * time.Second
|
|
} else {
|
|
ttl, err = parseutil.ParseDurationSecond(role.TTL)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
if ttl == 0 {
|
|
ttl = b.System().DefaultLeaseTTL()
|
|
}
|
|
|
|
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if maxTTL == 0 {
|
|
maxTTL = b.System().MaxLeaseTTL()
|
|
}
|
|
|
|
if ttl > maxTTL {
|
|
// Don't error if they were using system defaults, only error if
|
|
// they specifically chose a bad TTL
|
|
if !specifiedTTL {
|
|
ttl = maxTTL
|
|
} else {
|
|
return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second)
|
|
}
|
|
}
|
|
|
|
return ttl, nil
|
|
}
|
|
|
|
func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error {
|
|
if len(role.AllowedUserKeyTypesLengths) != 0 {
|
|
var keyType string
|
|
var keyBits int
|
|
|
|
switch k := publickey.(type) {
|
|
case ssh.CryptoPublicKey:
|
|
ff := k.CryptoPublicKey()
|
|
switch k := ff.(type) {
|
|
case *rsa.PublicKey:
|
|
keyType = "rsa"
|
|
keyBits = k.N.BitLen()
|
|
case *dsa.PublicKey:
|
|
keyType = "dsa"
|
|
keyBits = k.Parameters.P.BitLen()
|
|
case *ecdsa.PublicKey:
|
|
keyType = "ecdsa"
|
|
keyBits = k.Curve.Params().BitSize
|
|
case ed25519.PublicKey:
|
|
keyType = "ed25519"
|
|
default:
|
|
return fmt.Errorf("public key type of %s is not allowed", keyType)
|
|
}
|
|
default:
|
|
return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", k)
|
|
}
|
|
|
|
keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits)
|
|
|
|
var present bool
|
|
var pass bool
|
|
for _, kstr := range keyTypeToMapKey[keyType] {
|
|
allowed_values, ok := role.AllowedUserKeyTypesLengths[kstr]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
present = true
|
|
|
|
for _, value := range allowed_values {
|
|
if keyType == "rsa" || keyType == "dsa" {
|
|
// Regardless of map naming, we always need to validate the
|
|
// bit length of RSA and DSA keys. Use the keyType flag to
|
|
if keyBits == value {
|
|
pass = true
|
|
}
|
|
} else if kstr == "ec" || kstr == "ecdsa" {
|
|
// If the map string is "ecdsa", we have to validate the keyBits
|
|
// are a match for an allowed value, meaning that our curve
|
|
// is allowed. This isn't necessary when a named curve (e.g.
|
|
// ssh.KeyAlgoECDSA256) is allowed (and hence kstr is that),
|
|
// because keyBits is already specified in the kstr. Thus,
|
|
// we have conditioned around kstr and not keyType (like with
|
|
// rsa or dsa).
|
|
if keyBits == value {
|
|
pass = true
|
|
}
|
|
} else {
|
|
// We get here in two cases: we have a algo-named EC key
|
|
// matching a format specifier in the key map (e.g., a P-256
|
|
// key with a KeyAlgoECDSA256 entry in the map) or we have a
|
|
// ed25519 key (which is always allowed).
|
|
pass = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if !present {
|
|
return fmt.Errorf("key of type %s is not allowed", keyType)
|
|
}
|
|
|
|
if !pass {
|
|
return fmt.Errorf("key is of an invalid size: %v", keyBits)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errMsg, ok := r.(string)
|
|
if ok {
|
|
retCert = nil
|
|
retErr = errors.New(errMsg)
|
|
}
|
|
}
|
|
}()
|
|
|
|
serialNumber, err := certutil.GenerateSerialNumber()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to generate signed SSH key: signer is not an AlgorithmSigner")
|
|
}
|
|
|
|
// prepare certificate for signing
|
|
nonce := make([]byte, 32)
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce")
|
|
}
|
|
certificate := &ssh.Certificate{
|
|
Serial: serialNumber.Uint64(),
|
|
Key: b.PublicKey,
|
|
KeyId: b.KeyID,
|
|
ValidPrincipals: b.ValidPrincipals,
|
|
ValidAfter: uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()),
|
|
ValidBefore: uint64(now.Add(b.TTL).In(time.UTC).Unix()),
|
|
CertType: b.CertificateType,
|
|
Permissions: ssh.Permissions{
|
|
CriticalOptions: b.CriticalOptions,
|
|
Extensions: b.Extensions,
|
|
},
|
|
Nonce: nonce,
|
|
SignatureKey: sshAlgorithmSigner.PublicKey(),
|
|
}
|
|
|
|
// get bytes to sign; this is based on Certificate.bytesForSigning() from the go ssh lib
|
|
out := certificate.Marshal()
|
|
// Drop trailing signature length.
|
|
certificateBytes := out[:len(out)-4]
|
|
|
|
algo := b.Role.AlgorithmSigner
|
|
|
|
// Handle the new default algorithm selection process correctly.
|
|
if algo == DefaultAlgorithmSigner && sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA {
|
|
algo = ssh.SigAlgoRSASHA2256
|
|
} else if algo == DefaultAlgorithmSigner {
|
|
algo = ""
|
|
}
|
|
|
|
sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err)
|
|
}
|
|
|
|
certificate.Signature = sig
|
|
|
|
return certificate, nil
|
|
}
|
|
|
|
func createKeyTypeToMapKey(keyType string, keyBits int) map[string][]string {
|
|
keyTypeToMapKey := map[string][]string{
|
|
"rsa": {"rsa", ssh.KeyAlgoRSA},
|
|
"dsa": {"dsa", ssh.KeyAlgoDSA},
|
|
"ecdsa": {"ecdsa", "ec"},
|
|
"ed25519": {"ed25519", ssh.KeyAlgoED25519},
|
|
}
|
|
|
|
if keyType == "ecdsa" {
|
|
if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok {
|
|
keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo)
|
|
}
|
|
}
|
|
|
|
return keyTypeToMapKey
|
|
}
|