open-vault/builtin/logical/ssh/path_issue_sign.go
Ben Roberts d710f8e8dc
Evaluate ssh validprincipals user template before splitting (#16622)
The SSH secrets engine previously split the `validPrincipals` field
on commas and then, if user templating was enabled, evaluated the
template on each substring. This meant the identity template could only
ever return a single principal. There are use cases
where it would be helpful for identity metadata to contain a list
of valid principals and for the identity template to be able to inject
all of them as valid principals.

This change inverts the order of processing. First the template
is evaluated, and then the resulting string is split on commas.
This allows the identity template to return a single comma-separated
string with multiple permitted principals.

There is a potential security implication here: if a user is
allowed to update their own identity metadata, they may be able to
elevate their privileges in ways that were not previously possible.

Fixes #11038
2022-10-13 17:34:36 -05:00

559 lines
17 KiB
Go

package ssh
import (
"context"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"errors"
"fmt"
"io"
"regexp"
"strconv"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/crypto/ssh"
)
var (
	// containsTemplateRegex matches a string that contains at least one
	// identity-templating marker of the form {{...}}.
	containsTemplateRegex = regexp.MustCompile(`{{.+?}}`)

	// ecCurveBitsToAlgoName maps an ECDSA curve's bit size to the named SSH
	// key algorithm for that curve.
	ecCurveBitsToAlgoName = map[int]string{
		256: ssh.KeyAlgoECDSA256,
		384: ssh.KeyAlgoECDSA384,
		521: ssh.KeyAlgoECDSA521,
	}
)
// Note on ecCurveBitsToAlgoName: if an ECDSA key's curve size has no entry
// there, it may simply be a curve we haven't added a constant for yet. A role
// can still allow such a key (assuming x/crypto/ssh can parse it) with an
// `ec: <keyBits>` mapping rather than a named SSH key type, so an unknown
// curve must not be treated as an error.
// creationBundle gathers everything sign() needs to build and sign one SSH
// certificate.
type creationBundle struct {
	// KeyID becomes the certificate's key ID.
	KeyID string
	// ValidPrincipals is embedded verbatim as the certificate's principals.
	ValidPrincipals []string
	// PublicKey is the key being certified.
	PublicKey ssh.PublicKey
	// CertificateType is ssh.UserCert or ssh.HostCert.
	CertificateType uint32
	// TTL is added to the signing time to compute the certificate's expiry.
	TTL time.Duration
	// Signer is the CA key; sign() requires it to implement ssh.AlgorithmSigner.
	Signer ssh.Signer
	// Role supplies NotBeforeDuration and the AlgorithmSigner setting used by sign().
	Role *sshRole
	// CriticalOptions and Extensions become the certificate's permissions.
	CriticalOptions map[string]string
	Extensions      map[string]string
}
// pathSignIssueCertificateHelper validates a sign/issue request against role
// and returns a response carrying the signed certificate.
//
// Problems with the request itself (key ID, certificate type, principals,
// TTL, critical options, extensions) are reported as error responses with a
// nil error, so they surface as client errors. Failing to render the role's
// DefaultUser template, to load or parse the CA private key, or to sign the
// certificate is returned as an error instead.
func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, publicKey ssh.PublicKey) (*logical.Response, error) {
	userError := func(err error) (*logical.Response, error) {
		return logical.ErrorResponse(err.Error()), nil
	}

	keyID, err := b.calculateKeyID(data, req, role, publicKey)
	if err != nil {
		return userError(err)
	}

	certificateType, err := b.calculateCertificateType(data, role)
	if err != nil {
		return userError(err)
	}

	// Host certificates are checked against the role's domains; user
	// certificates against its users, falling back to the (optionally
	// templated) default user when no principals are requested.
	var principals []string
	if certificateType == ssh.HostCert {
		principals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, role.AllowedDomainsTemplate, validateValidPrincipalForHosts(role))
	} else {
		defaultPrincipal := role.DefaultUser
		if role.DefaultUserTemplate {
			rendered, renderErr := b.renderPrincipal(role.DefaultUser, req)
			if renderErr != nil {
				return nil, renderErr
			}
			defaultPrincipal = rendered
		}
		principals, err = b.calculateValidPrincipals(data, req, role, defaultPrincipal, role.AllowedUsers, role.AllowedUsersTemplate, strutil.StrListContains)
	}
	if err != nil {
		return userError(err)
	}

	ttl, err := b.calculateTTL(data, role)
	if err != nil {
		return userError(err)
	}

	criticalOptions, err := b.calculateCriticalOptions(data, role)
	if err != nil {
		return userError(err)
	}

	extensions, extTemplatingSkipped, err := b.calculateExtensions(data, req, role)
	if err != nil {
		return userError(err)
	}

	// Load and parse the CA key that will sign the certificate.
	caEntry, err := caKey(ctx, req.Storage, caPrivateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA private key: %w", err)
	}
	if caEntry == nil || caEntry.Key == "" {
		return nil, errors.New("failed to read CA private key")
	}
	signer, err := ssh.ParsePrivateKey([]byte(caEntry.Key))
	if err != nil {
		return nil, fmt.Errorf("failed to parse stored CA private key: %w", err)
	}

	bundle := &creationBundle{
		KeyID:           keyID,
		ValidPrincipals: principals,
		PublicKey:       publicKey,
		CertificateType: certificateType,
		TTL:             ttl,
		Signer:          signer,
		Role:            role,
		CriticalOptions: criticalOptions,
		Extensions:      extensions,
	}
	certificate, err := bundle.sign()
	if err != nil {
		return nil, err
	}

	encoded := ssh.MarshalAuthorizedKey(certificate)
	if len(encoded) == 0 {
		return nil, errors.New("error marshaling signed certificate")
	}

	resp := &logical.Response{
		Data: map[string]interface{}{
			"serial_number": strconv.FormatUint(certificate.Serial, 16),
			"signed_key":    string(encoded),
		},
	}
	if extTemplatingSkipped {
		resp.AddWarning("default_extension templating enabled with at least one extension requiring identity templating. However, this request lacked identity entity information, causing one or more extensions to be skipped from the generated certificate.")
	}
	return resp, nil
}
// renderPrincipal expands identity templates in principal for the entity
// that made the request.
//
// If principal contains no {{...}} markers, or the request carries no entity
// ID, principal is returned unchanged. Otherwise the whole string is rendered
// with framework.PopulateIdentityTemplate. Callers split the result on commas
// afterwards, which lets one template expand to several principals.
func (b *backend) renderPrincipal(principal string, req *logical.Request) (string, error) {
	// Static principal, or a templated one with no entity to render it for.
	if !containsTemplateRegex.MatchString(principal) || req.EntityID == "" {
		return principal, nil
	}

	rendered, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System())
	if err != nil {
		// %w keeps the underlying templating error inspectable with
		// errors.Is/As; the message text is the same as with %s.
		return "", fmt.Errorf("template '%s' could not be rendered -> %w", principal, err)
	}
	return rendered, nil
}
// calculateValidPrincipals determines the principals to embed in the
// certificate and checks them against what the role permits.
//
// The requested principals come from the "valid_principals" field, or from
// defaultPrincipal when that field is absent. principalsAllowedByRole is a
// comma-separated list. When enableTemplating is set, the whole string is
// rendered as an identity template first and split afterwards, so one
// template may expand to several principals. A role value of exactly "*"
// allows any principal. validatePrincipal decides whether a single requested
// principal is covered by the deduplicated allowed list.
//
// It returns nil, nil when no principals are requested. Both lists pass
// through strutil.RemoveDuplicates, which may reorder them. The role
// parameter is currently unused.
func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, enableTemplating bool, validatePrincipal func([]string, string) bool) ([]string, error) {
	validPrincipals := defaultPrincipal
	if raw, ok := data.GetOk("valid_principals"); ok {
		validPrincipals = raw.(string)
	}
	parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false)

	// Render the role's template (if enabled) before splitting on commas.
	allowedSource := principalsAllowedByRole
	if enableTemplating {
		rendered, err := b.renderPrincipal(principalsAllowedByRole, req)
		if err != nil {
			return nil, err
		}
		allowedSource = rendered
	}
	allowedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(allowedSource, ","), false)

	switch {
	case len(parsedPrincipals) == 0:
		// There is nothing to process.
		return nil, nil
	case len(allowedPrincipals) == 0:
		// Principals were requested, but the role allows none.
		return nil, errors.New("role is not configured to allow any principals")
	case principalsAllowedByRole == "*":
		// Role was explicitly configured to allow any principal.
		return parsedPrincipals, nil
	}

	// allowedPrincipals is already deduplicated above, so it is passed as-is
	// rather than being re-deduplicated for every requested principal.
	for _, principal := range parsedPrincipals {
		if !validatePrincipal(allowedPrincipals, principal) {
			return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal)
		}
	}
	return parsedPrincipals, nil
}
// validateValidPrincipalForHosts returns a principal validator for host
// certificates.
//
// A requested principal is accepted when, for some allowed domain, either:
//   - it equals the domain and the role allows bare domains, or
//   - it ends in "."+domain and the role allows subdomains.
//
// The role's flags are read every time the returned function runs.
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
	return func(allowedPrincipals []string, validPrincipal string) bool {
		for _, domain := range allowedPrincipals {
			switch {
			case validPrincipal == domain && role.AllowBareDomains:
				return true
			case role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+domain):
				return true
			}
		}
		return false
	}
}
// calculateCertificateType maps the request's "cert_type" field ("user" or
// "host") to ssh.UserCert or ssh.HostCert. It fails if the role does not
// allow that type, or if the value is anything else.
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
	switch data.Get("cert_type").(string) {
	case "user":
		if role.AllowUserCertificates {
			return ssh.UserCert, nil
		}
		return 0, errors.New("cert_type 'user' is not allowed by role")
	case "host":
		if role.AllowHostCertificates {
			return ssh.HostCert, nil
		}
		return 0, errors.New("cert_type 'host' is not allowed by role")
	}
	return 0, errors.New("cert_type must be either 'user' or 'host'")
}
// calculateKeyID picks the certificate's key ID.
//
// A "key_id" supplied in the request is used verbatim if the role allows
// user-chosen key IDs, and rejected otherwise. Without one, the key ID is
// generated from a format:
//   - the role's KeyIDFormat, if set;
//   - otherwise a default that includes the token display name when the
//     request has one.
//
// The {{token_display_name}}, {{role_name}} and {{public_key_hash}}
// (hex SHA-256 of the marshaled public key) placeholders are substituted.
func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) {
	if requested := data.Get("key_id").(string); requested != "" {
		if role.AllowUserKeyIDs {
			return requested, nil
		}
		return "", errors.New("setting key_id is not allowed by role")
	}

	var format string
	switch {
	case role.KeyIDFormat != "":
		format = role.KeyIDFormat
	case req.DisplayName == "":
		format = "vault-{{public_key_hash}}"
	default:
		format = "vault-{{token_display_name}}-{{public_key_hash}}"
	}

	substitutions := map[string]string{
		"token_display_name": req.DisplayName,
		"role_name":          data.Get("role").(string),
		"public_key_hash":    fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())),
	}
	return substQuery(format, substitutions), nil
}
// calculateCriticalOptions returns the critical options to place on the
// certificate.
//
// If the request has no "critical_options", the role's
// DefaultCriticalOptions are returned. Otherwise the requested options are
// used. When the role sets a comma-separated AllowedCriticalOptions list,
// every requested option name must appear in it; an empty list allows any
// option. NOTE(review): list entries come from a plain comma split with no
// whitespace trimming.
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
	requested := data.Get("critical_options").(map[string]interface{})
	if len(requested) == 0 {
		return role.DefaultCriticalOptions, nil
	}

	criticalOptions := convertMapToStringValue(requested)
	if role.AllowedCriticalOptions == "" {
		return criticalOptions, nil
	}

	allowed := strings.Split(role.AllowedCriticalOptions, ",")
	var rejected []string
	for name := range criticalOptions {
		if !strutil.StrListContains(allowed, name) {
			rejected = append(rejected, name)
		}
	}
	if len(rejected) > 0 {
		return nil, fmt.Errorf("critical options not on allowed list: %v", rejected)
	}
	return criticalOptions, nil
}
// calculateExtensions determines the extensions to place on the certificate.
//
// If the request supplies "extensions", they are used as-is and the role's
// defaults are ignored. This requires either that the role's AllowedExtensions
// is exactly "*", or that every requested key appears in that comma-separated
// list.
//
// Otherwise the role's DefaultExtensions are used. When
// DefaultExtensionsTemplate is set, values containing {{...}} markers are
// rendered as identity templates for the requesting entity. If the request
// has no entity ID, those templated extensions are left out and the returned
// bool is true, so the caller can warn about it.
func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, bool, error) {
	unparsedExtensions := data.Get("extensions").(map[string]interface{})
	if len(unparsedExtensions) > 0 {
		requested := convertMapToStringValue(unparsedExtensions)
		if role.AllowedExtensions == "*" {
			// Allowed extensions was configured to allow all.
			return requested, false, nil
		}

		var notAllowed []string
		allowedExtensions := strings.Split(role.AllowedExtensions, ",")
		for extensionKey := range requested {
			if !strutil.StrListContains(allowedExtensions, extensionKey) {
				notAllowed = append(notAllowed, extensionKey)
			}
		}
		if len(notAllowed) != 0 {
			return nil, false, fmt.Errorf("extensions %v are not on allowed list", notAllowed)
		}
		return requested, false, nil
	}

	if !role.DefaultExtensionsTemplate {
		return role.DefaultExtensions, false, nil
	}

	extensions := make(map[string]string, len(role.DefaultExtensions))
	missingEntityInfo := false
	for extensionKey, extensionValue := range role.DefaultExtensions {
		if !containsTemplateRegex.MatchString(extensionValue) {
			// Static extension value.
			extensions[extensionKey] = extensionValue
			continue
		}
		if req.EntityID == "" {
			// Cannot render without an entity; skip and let the caller warn.
			missingEntityInfo = true
			continue
		}
		rendered, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System())
		if err != nil {
			return nil, false, fmt.Errorf("template '%s' could not be rendered -> %w", extensionValue, err)
		}
		extensions[extensionKey] = rendered
	}
	return extensions, missingEntityInfo, nil
}
// calculateTTL returns the certificate lifetime.
//
// The TTL is taken from the request's "ttl" field if present, otherwise from
// the role's TTL; a zero result falls back to the system default lease TTL.
// The ceiling is the role's MaxTTL, or the system max lease TTL when that is
// zero. A TTL above the ceiling is an error only if the caller explicitly
// asked for it; otherwise it is clamped to the ceiling.
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
	var ttl time.Duration
	ttlRaw, specifiedTTL := data.GetOk("ttl")
	if specifiedTTL {
		ttl = time.Duration(ttlRaw.(int)) * time.Second
	} else {
		roleTTL, err := parseutil.ParseDurationSecond(role.TTL)
		if err != nil {
			return 0, err
		}
		ttl = roleTTL
	}
	if ttl == 0 {
		ttl = b.System().DefaultLeaseTTL()
	}

	maxTTL, err := parseutil.ParseDurationSecond(role.MaxTTL)
	if err != nil {
		return 0, err
	}
	if maxTTL == 0 {
		maxTTL = b.System().MaxLeaseTTL()
	}

	if ttl <= maxTTL {
		return ttl, nil
	}
	if specifiedTTL {
		return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second)
	}
	return maxTTL, nil
}
// validateSignedKeyRequirements checks publickey against the role's
// AllowedUserKeyTypesLengths policy. An empty policy allows every key.
//
// The key's type and size come from its underlying crypto key: RSA modulus
// bits, DSA prime bits, or ECDSA curve bits (ed25519 has no size). The key is
// allowed when the policy has at least one entry under a name for its type
// (see createKeyTypeToMapKey) and that entry accepts it:
//   - RSA and DSA keys must have a bit length listed in the entry;
//   - ECDSA keys matched through the generic "ec"/"ecdsa" names must have a
//     curve size listed in the entry;
//   - other matches (a curve-specific ECDSA algorithm name, or ed25519) are
//     accepted by any listed value.
func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error {
	if len(role.AllowedUserKeyTypesLengths) == 0 {
		return nil
	}

	cryptoKey, isCrypto := publickey.(ssh.CryptoPublicKey)
	if !isCrypto {
		return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", publickey)
	}

	var keyType string
	var keyBits int
	switch k := cryptoKey.CryptoPublicKey().(type) {
	case *rsa.PublicKey:
		keyType = "rsa"
		keyBits = k.N.BitLen()
	case *dsa.PublicKey:
		keyType = "dsa"
		keyBits = k.Parameters.P.BitLen()
	case *ecdsa.PublicKey:
		keyType = "ecdsa"
		keyBits = k.Curve.Params().BitSize
	case ed25519.PublicKey:
		keyType = "ed25519"
	default:
		// keyType has not been set on this path, so report the concrete
		// key type rather than an empty string.
		return fmt.Errorf("public key type of %T is not allowed", k)
	}

	var present, pass bool
	for _, kstr := range createKeyTypeToMapKey(keyType, keyBits)[keyType] {
		allowedValues, found := role.AllowedUserKeyTypesLengths[kstr]
		if !found {
			continue
		}
		present = true

		for _, value := range allowedValues {
			switch {
			case keyType == "rsa" || keyType == "dsa":
				// Regardless of which map name matched, the bit length of
				// RSA and DSA keys must always be validated.
				if keyBits == value {
					pass = true
				}
			case kstr == "ec" || kstr == "ecdsa":
				// Generic ECDSA entries list allowed curve sizes, so keyBits
				// must match. A curve-specific name (e.g. ssh.KeyAlgoECDSA256)
				// already encodes the size, which is why this branches on
				// kstr rather than keyType.
				if keyBits == value {
					pass = true
				}
			default:
				// A curve-specific ECDSA entry matching this key's curve, or
				// an ed25519 entry: the name match is sufficient.
				pass = true
			}
		}
	}

	if !present {
		return fmt.Errorf("key of type %s is not allowed", keyType)
	}
	if !pass {
		return fmt.Errorf("key is of an invalid size: %v", keyBits)
	}
	return nil
}
// sign builds an SSH certificate from the bundle and signs it with b.Signer,
// which must implement ssh.AlgorithmSigner.
//
// The certificate gets:
//   - a fresh serial number and a 32-byte random nonce;
//   - a validity window from now minus Role.NotBeforeDuration to now plus TTL.
//
// The signature algorithm is Role.AlgorithmSigner. DefaultAlgorithmSigner
// means rsa-sha2-256 (ssh.SigAlgoRSASHA2256) for an RSA CA key, and an empty
// algorithm for any other key type.
//
// Panics raised while building or signing the certificate are recovered and
// returned as errors, so a nil certificate is never returned with a nil error.
func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		retCert = nil
		// Previously only string panics were converted; any other value was
		// swallowed and (nil, nil) was returned to the caller.
		switch v := r.(type) {
		case string:
			retErr = errors.New(v)
		case error:
			retErr = fmt.Errorf("failed to generate signed SSH key: %w", v)
		default:
			retErr = fmt.Errorf("failed to generate signed SSH key: %v", v)
		}
	}()

	serialNumber, err := certutil.GenerateSerialNumber()
	if err != nil {
		return nil, err
	}

	now := time.Now()

	sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner)
	if !ok {
		return nil, errors.New("failed to generate signed SSH key: signer is not an AlgorithmSigner")
	}

	// Prepare the certificate for signing.
	nonce := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce: %w", err)
	}
	certificate := &ssh.Certificate{
		Serial:          serialNumber.Uint64(),
		Key:             b.PublicKey,
		KeyId:           b.KeyID,
		ValidPrincipals: b.ValidPrincipals,
		ValidAfter:      uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()),
		ValidBefore:     uint64(now.Add(b.TTL).In(time.UTC).Unix()),
		CertType:        b.CertificateType,
		Permissions: ssh.Permissions{
			CriticalOptions: b.CriticalOptions,
			Extensions:      b.Extensions,
		},
		Nonce:        nonce,
		SignatureKey: sshAlgorithmSigner.PublicKey(),
	}

	// Get the bytes to sign. This follows Certificate.bytesForSigning() from
	// the Go ssh library: marshal, then drop the trailing signature length.
	out := certificate.Marshal()
	certificateBytes := out[:len(out)-4]

	algo := b.Role.AlgorithmSigner
	if algo == DefaultAlgorithmSigner {
		if sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA {
			algo = ssh.SigAlgoRSASHA2256
		} else {
			algo = ""
		}
	}

	sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo)
	if err != nil {
		return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err)
	}

	certificate.Signature = sig
	return certificate, nil
}
func createKeyTypeToMapKey(keyType string, keyBits int) map[string][]string {
keyTypeToMapKey := map[string][]string{
"rsa": {"rsa", ssh.KeyAlgoRSA},
"dsa": {"dsa", ssh.KeyAlgoDSA},
"ecdsa": {"ecdsa", "ec"},
"ed25519": {"ed25519", ssh.KeyAlgoED25519},
}
if keyType == "ecdsa" {
if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok {
keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo)
}
}
return keyTypeToMapKey
}