2023-03-15 16:00:52 +00:00
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
2022-06-10 13:48:19 +00:00
package ssh
import (
"context"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"errors"
"fmt"
"io"
"regexp"
"strconv"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/crypto/ssh"
)
2022-06-17 15:06:17 +00:00
// containsTemplateRegex is compiled once at package scope and matched against
// principal and default-extension values to decide whether they carry
// templating markers that need to be rendered.
var containsTemplateRegex = regexp . MustCompile ( ` {{ . + ? }} ` )
2022-06-10 13:48:19 +00:00
// ecCurveBitsToAlgoName maps an ECDSA curve size in bits to the name of the
// corresponding x/crypto/ssh ECDSA key algorithm constant.
var ecCurveBitsToAlgoName = map [ int ] string {
256 : ssh . KeyAlgoECDSA256 ,
384 : ssh . KeyAlgoECDSA384 ,
521 : ssh . KeyAlgoECDSA521 ,
}
// If the algorithm is not found, it could be that we have a curve
// that we haven't added a constant for yet. But they could allow it
// (assuming x/crypto/ssh can parse it) via setting an ec: <keyBits>
// mapping rather than using a named SSH key type, so erroring out here
// isn't advisable.
// creationBundle collects everything sign() needs to build and sign an SSH
// certificate.
type creationBundle struct {
// KeyID is written to the certificate's KeyId field.
KeyID string
// ValidPrincipals is written to the certificate's ValidPrincipals field.
ValidPrincipals [ ] string
// PublicKey is the key being certified.
PublicKey ssh . PublicKey
// CertificateType is written to the certificate's CertType field.
CertificateType uint32
// TTL is added to the signing time to compute ValidBefore.
TTL time . Duration
// Signer signs the certificate; sign() requires it to implement
// ssh.AlgorithmSigner.
Signer ssh . Signer
// Role supplies NotBeforeDuration and AlgorithmSigner during signing.
Role * sshRole
// CriticalOptions and Extensions populate the certificate's Permissions.
CriticalOptions map [ string ] string
Extensions map [ string ] string
}
func ( b * backend ) pathSignIssueCertificateHelper ( ctx context . Context , req * logical . Request , data * framework . FieldData , role * sshRole , publicKey ssh . PublicKey ) ( * logical . Response , error ) {
// Note that these various functions always return "user errors" so we pass
// them as 4xx values
keyID , err := b . calculateKeyID ( data , req , role , publicKey )
if err != nil {
return logical . ErrorResponse ( err . Error ( ) ) , nil
}
certificateType , err := b . calculateCertificateType ( data , role )
if err != nil {
return logical . ErrorResponse ( err . Error ( ) ) , nil
}
var parsedPrincipals [ ] string
if certificateType == ssh . HostCert {
2022-08-16 19:59:29 +00:00
parsedPrincipals , err = b . calculateValidPrincipals ( data , req , role , "" , role . AllowedDomains , role . AllowedDomainsTemplate , validateValidPrincipalForHosts ( role ) )
2022-06-10 13:48:19 +00:00
if err != nil {
return logical . ErrorResponse ( err . Error ( ) ) , nil
}
} else {
2022-07-29 13:45:52 +00:00
defaultPrincipal := role . DefaultUser
if role . DefaultUserTemplate {
defaultPrincipal , err = b . renderPrincipal ( role . DefaultUser , req )
if err != nil {
2022-07-29 14:18:22 +00:00
return nil , err
2022-07-29 13:45:52 +00:00
}
}
2022-08-16 19:59:29 +00:00
parsedPrincipals , err = b . calculateValidPrincipals ( data , req , role , defaultPrincipal , role . AllowedUsers , role . AllowedUsersTemplate , strutil . StrListContains )
2022-06-10 13:48:19 +00:00
if err != nil {
return logical . ErrorResponse ( err . Error ( ) ) , nil
}
}
ttl , err := b . calculateTTL ( data , role )
if err != nil {
return logical . ErrorResponse ( err . Error ( ) ) , nil
}
criticalOptions , err := b . calculateCriticalOptions ( data , role )
if err != nil {
return logical . ErrorResponse ( err . Error ( ) ) , nil
}
2022-10-06 18:00:56 +00:00
extensions , addExtTemplatingWarning , err := b . calculateExtensions ( data , req , role )
2022-06-10 13:48:19 +00:00
if err != nil {
return logical . ErrorResponse ( err . Error ( ) ) , nil
}
privateKeyEntry , err := caKey ( ctx , req . Storage , caPrivateKey )
if err != nil {
return nil , fmt . Errorf ( "failed to read CA private key: %w" , err )
}
if privateKeyEntry == nil || privateKeyEntry . Key == "" {
return nil , errors . New ( "failed to read CA private key" )
}
signer , err := ssh . ParsePrivateKey ( [ ] byte ( privateKeyEntry . Key ) )
if err != nil {
return nil , fmt . Errorf ( "failed to parse stored CA private key: %w" , err )
}
cBundle := creationBundle {
KeyID : keyID ,
PublicKey : publicKey ,
Signer : signer ,
ValidPrincipals : parsedPrincipals ,
TTL : ttl ,
CertificateType : certificateType ,
Role : role ,
CriticalOptions : criticalOptions ,
Extensions : extensions ,
}
certificate , err := cBundle . sign ( )
if err != nil {
return nil , err
}
signedSSHCertificate := ssh . MarshalAuthorizedKey ( certificate )
if len ( signedSSHCertificate ) == 0 {
return nil , errors . New ( "error marshaling signed certificate" )
}
response := & logical . Response {
Data : map [ string ] interface { } {
"serial_number" : strconv . FormatUint ( certificate . Serial , 16 ) ,
"signed_key" : string ( signedSSHCertificate ) ,
} ,
}
2022-10-06 18:00:56 +00:00
if addExtTemplatingWarning {
response . AddWarning ( "default_extension templating enabled with at least one extension requiring identity templating. However, this request lacked identity entity information, causing one or more extensions to be skipped from the generated certificate." )
}
2022-06-10 13:48:19 +00:00
return response , nil
}
2022-07-29 13:45:52 +00:00
// renderPrincipal returns principal with identity templating applied.
//
// The value is returned unchanged when it contains no templating markers or
// when the request carries no entity ID. Otherwise it is rendered against the
// request's entity, and a rendering failure is returned as an error.
func (b *backend) renderPrincipal(principal string, req *logical.Request) (string, error) {
	// Static principal, or nothing to render the template against.
	if !containsTemplateRegex.MatchString(principal) || req.EntityID == "" {
		return principal, nil
	}

	// Resolve the template using the entity ID from the request.
	rendered, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System())
	if err != nil {
		return "", fmt.Errorf("template '%s' could not be rendered -> %s", principal, err)
	}
	return rendered, nil
}
2022-08-16 19:59:29 +00:00
func ( b * backend ) calculateValidPrincipals ( data * framework . FieldData , req * logical . Request , role * sshRole , defaultPrincipal , principalsAllowedByRole string , enableTemplating bool , validatePrincipal func ( [ ] string , string ) bool ) ( [ ] string , error ) {
2022-06-10 13:48:19 +00:00
validPrincipals := ""
validPrincipalsRaw , ok := data . GetOk ( "valid_principals" )
if ok {
validPrincipals = validPrincipalsRaw . ( string )
} else {
validPrincipals = defaultPrincipal
}
parsedPrincipals := strutil . RemoveDuplicates ( strutil . ParseStringSlice ( validPrincipals , "," ) , false )
// Build list of allowed Principals from template and static principalsAllowedByRole
var allowedPrincipals [ ] string
2022-10-13 22:34:36 +00:00
if enableTemplating {
rendered , err := b . renderPrincipal ( principalsAllowedByRole , req )
if err != nil {
return nil , err
2022-06-10 13:48:19 +00:00
}
2022-10-13 22:34:36 +00:00
allowedPrincipals = strutil . RemoveDuplicates ( strutil . ParseStringSlice ( rendered , "," ) , false )
} else {
allowedPrincipals = strutil . RemoveDuplicates ( strutil . ParseStringSlice ( principalsAllowedByRole , "," ) , false )
2022-06-10 13:48:19 +00:00
}
switch {
case len ( parsedPrincipals ) == 0 :
// There is nothing to process
return nil , nil
case len ( allowedPrincipals ) == 0 :
// User has requested principals to be set, but role is not configured
// with any principals
return nil , fmt . Errorf ( "role is not configured to allow any principals" )
default :
// Role was explicitly configured to allow any principal.
if principalsAllowedByRole == "*" {
return parsedPrincipals , nil
}
for _ , principal := range parsedPrincipals {
if ! validatePrincipal ( strutil . RemoveDuplicates ( allowedPrincipals , false ) , principal ) {
return nil , fmt . Errorf ( "%v is not a valid value for valid_principals" , principal )
}
}
return parsedPrincipals , nil
}
}
// validateValidPrincipalForHosts returns a validator for host principals.
//
// The returned function accepts validPrincipal when it equals one of
// allowedPrincipals and the role allows bare domains, or when it ends with
// "." followed by one of allowedPrincipals and the role allows subdomains.
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
	return func(allowedPrincipals []string, validPrincipal string) bool {
		for i := range allowedPrincipals {
			domain := allowedPrincipals[i]
			switch {
			case domain == validPrincipal && role.AllowBareDomains:
				return true
			case role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+domain):
				return true
			}
		}
		return false
	}
}
// calculateCertificateType maps the "cert_type" request field to the matching
// SSH certificate type constant. It returns an error when the value is
// neither "user" nor "host", or when the role does not allow the requested
// type.
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
	requested := data.Get("cert_type").(string)

	if requested == "user" {
		if role.AllowUserCertificates {
			return ssh.UserCert, nil
		}
		return 0, errors.New("cert_type 'user' is not allowed by role")
	}

	if requested == "host" {
		if role.AllowHostCertificates {
			return ssh.HostCert, nil
		}
		return 0, errors.New("cert_type 'host' is not allowed by role")
	}

	return 0, errors.New("cert_type must be either 'user' or 'host'")
}
// calculateKeyID picks the key ID for the certificate.
//
// A "key_id" supplied in the request is used verbatim when the role allows
// user key IDs and rejected otherwise. Without one, the ID is produced from a
// format string: the role's KeyIDFormat when set, else a built-in format that
// includes the token display name only when the request has one.
func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) {
	if requested := data.Get("key_id").(string); requested != "" {
		if role.AllowUserKeyIDs {
			return requested, nil
		}
		return "", fmt.Errorf("setting key_id is not allowed by role")
	}

	format := role.KeyIDFormat
	if format == "" {
		if req.DisplayName == "" {
			format = "vault-{{public_key_hash}}"
		} else {
			format = "vault-{{token_display_name}}-{{public_key_hash}}"
		}
	}

	// Values made available to the format string.
	substitutions := map[string]string{
		"token_display_name": req.DisplayName,
		"role_name":          data.Get("role").(string),
		"public_key_hash":    fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())),
	}

	return substQuery(format, substitutions), nil
}
// calculateCriticalOptions returns the critical options for the certificate.
//
// When the request supplies none, the role's default critical options are
// returned. Otherwise the requested options are returned, after checking them
// against the role's comma-separated allow list when that list is non-empty.
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
	requested := data.Get("critical_options").(map[string]interface{})
	if len(requested) == 0 {
		return role.DefaultCriticalOptions, nil
	}

	criticalOptions := convertMapToStringValue(requested)
	if role.AllowedCriticalOptions == "" {
		// No allow list configured: accept whatever was requested.
		return criticalOptions, nil
	}

	allowed := strings.Split(role.AllowedCriticalOptions, ",")
	var rejected []string
	for name := range criticalOptions {
		if !strutil.StrListContains(allowed, name) {
			rejected = append(rejected, name)
		}
	}
	if len(rejected) > 0 {
		return nil, fmt.Errorf("critical options not on allowed list: %v", rejected)
	}

	return criticalOptions, nil
}
2022-10-06 18:00:56 +00:00
func ( b * backend ) calculateExtensions ( data * framework . FieldData , req * logical . Request , role * sshRole ) ( map [ string ] string , bool , error ) {
2022-06-10 13:48:19 +00:00
unparsedExtensions := data . Get ( "extensions" ) . ( map [ string ] interface { } )
extensions := make ( map [ string ] string )
if len ( unparsedExtensions ) > 0 {
extensions := convertMapToStringValue ( unparsedExtensions )
if role . AllowedExtensions == "*" {
// Allowed extensions was configured to allow all
2022-10-06 18:00:56 +00:00
return extensions , false , nil
2022-06-10 13:48:19 +00:00
}
notAllowed := [ ] string { }
allowedExtensions := strings . Split ( role . AllowedExtensions , "," )
for extensionKey := range extensions {
if ! strutil . StrListContains ( allowedExtensions , extensionKey ) {
notAllowed = append ( notAllowed , extensionKey )
}
}
if len ( notAllowed ) != 0 {
2022-10-06 18:00:56 +00:00
return nil , false , fmt . Errorf ( "extensions %v are not on allowed list" , notAllowed )
2022-06-10 13:48:19 +00:00
}
2022-10-06 18:00:56 +00:00
return extensions , false , nil
2022-06-10 13:48:19 +00:00
}
2022-10-06 18:00:56 +00:00
haveMissingEntityInfoWithTemplatedExt := false
2022-06-10 13:48:19 +00:00
if role . DefaultExtensionsTemplate {
for extensionKey , extensionValue := range role . DefaultExtensions {
// Look for templating markers {{ .* }}
2022-06-17 15:06:17 +00:00
matched := containsTemplateRegex . MatchString ( extensionValue )
2022-06-10 13:48:19 +00:00
if matched {
if req . EntityID != "" {
// Retrieve extension value based on template + entityID from request.
templateExtensionValue , err := framework . PopulateIdentityTemplate ( extensionValue , req . EntityID , b . System ( ) )
if err == nil {
// Template returned an extension value that we can use
extensions [ extensionKey ] = templateExtensionValue
} else {
2022-10-06 18:00:56 +00:00
return nil , false , fmt . Errorf ( "template '%s' could not be rendered -> %s" , extensionValue , err )
2022-06-10 13:48:19 +00:00
}
2022-10-06 18:00:56 +00:00
} else {
haveMissingEntityInfoWithTemplatedExt = true
2022-06-10 13:48:19 +00:00
}
} else {
// Static extension value or err template
extensions [ extensionKey ] = extensionValue
}
}
} else {
extensions = role . DefaultExtensions
}
2022-10-06 18:00:56 +00:00
return extensions , haveMissingEntityInfoWithTemplatedExt , nil
2022-06-10 13:48:19 +00:00
}
// calculateTTL resolves the certificate TTL.
//
// The TTL comes from the "ttl" request field (in seconds) when set, else from
// the role's TTL; a zero value falls back to the system default lease TTL.
// The maximum comes from the role's MaxTTL, with zero falling back to the
// system maximum lease TTL. A TTL above the maximum is an error when it was
// requested explicitly and is otherwise capped at the maximum.
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
	var ttl time.Duration

	requested, explicit := data.GetOk("ttl")
	if explicit {
		ttl = time.Duration(requested.(int)) * time.Second
	} else {
		roleTTL, err := parseutil.ParseDurationSecond(role.TTL)
		if err != nil {
			return 0, err
		}
		ttl = roleTTL
	}
	if ttl == 0 {
		ttl = b.System().DefaultLeaseTTL()
	}

	limit, err := parseutil.ParseDurationSecond(role.MaxTTL)
	if err != nil {
		return 0, err
	}
	if limit == 0 {
		limit = b.System().MaxLeaseTTL()
	}

	switch {
	case ttl <= limit:
		return ttl, nil
	case explicit:
		// Only error if the caller specifically chose a bad TTL.
		return 0, fmt.Errorf("ttl is larger than maximum allowed %d", limit/time.Second)
	default:
		// Defaults were in play, so quietly cap at the maximum.
		return limit, nil
	}
}
// validateSignedKeyRequirements checks publickey against the role's
// AllowedUserKeyTypesLengths restrictions.
//
// With no restrictions configured, every key is accepted. Otherwise the key
// must be an RSA, DSA, ECDSA or Ed25519 key whose type appears in the map
// under one of its accepted names and, where a size applies, whose bit length
// matches one of the values listed for that name. It returns an error
// describing the first failed requirement, or nil when the key is acceptable.
func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error {
	if len(role.AllowedUserKeyTypesLengths) == 0 {
		return nil
	}

	cryptoKey, ok := publickey.(ssh.CryptoPublicKey)
	if !ok {
		return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", publickey)
	}

	var keyType string
	var keyBits int
	switch k := cryptoKey.CryptoPublicKey().(type) {
	case *rsa.PublicKey:
		keyType = "rsa"
		keyBits = k.N.BitLen()
	case *dsa.PublicKey:
		keyType = "dsa"
		keyBits = k.Parameters.P.BitLen()
	case *ecdsa.PublicKey:
		keyType = "ecdsa"
		keyBits = k.Curve.Params().BitSize
	case ed25519.PublicKey:
		keyType = "ed25519"
	default:
		// keyType is still empty here, so report the concrete Go type of
		// the key; otherwise the message would name no type at all.
		return fmt.Errorf("public key type of %T is not allowed", k)
	}

	keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits)

	var present, pass bool
	for _, kstr := range keyTypeToMapKey[keyType] {
		allowedValues, ok := role.AllowedUserKeyTypesLengths[kstr]
		if !ok {
			continue
		}
		present = true

		for _, value := range allowedValues {
			switch {
			case keyType == "rsa" || keyType == "dsa":
				// Regardless of map naming, we always need to validate the
				// bit length of RSA and DSA keys, so this check is
				// conditioned on keyType rather than on the map key.
				if keyBits == value {
					pass = true
				}
			case kstr == "ec" || kstr == "ecdsa":
				// If the map string is "ecdsa", we have to validate the
				// keyBits are a match for an allowed value, meaning that our
				// curve is allowed. This isn't necessary when a named curve
				// (e.g. ssh.KeyAlgoECDSA256) is allowed (and hence kstr is
				// that), because keyBits is already specified in the kstr.
				// Thus, we have conditioned around kstr and not keyType
				// (like with rsa or dsa).
				if keyBits == value {
					pass = true
				}
			default:
				// We get here in two cases: we have an algo-named EC key
				// matching a format specifier in the key map (e.g., a P-256
				// key with a KeyAlgoECDSA256 entry in the map) or we have an
				// ed25519 key (which is always allowed).
				pass = true
			}
		}
	}

	if !present {
		return fmt.Errorf("key of type %s is not allowed", keyType)
	}
	if !pass {
		return fmt.Errorf("key is of an invalid size: %v", keyBits)
	}
	return nil
}
// sign builds an SSH certificate from the bundle and signs it with the
// bundle's signer.
//
// The certificate is valid from now minus the role's NotBeforeDuration until
// now plus the bundle's TTL. The signature algorithm is the role's
// AlgorithmSigner; when that is DefaultAlgorithmSigner, RSA-SHA2-256 is used
// for RSA CA keys and the signer's own default for every other key type.
//
// Any panic raised while building or signing the certificate is recovered and
// returned as an error, so callers never see a nil certificate together with
// a nil error.
func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		retCert = nil
		switch v := r.(type) {
		case string:
			retErr = errors.New(v)
		case error:
			retErr = fmt.Errorf("failed to generate signed SSH key: %w", v)
		default:
			retErr = fmt.Errorf("failed to generate signed SSH key: %v", v)
		}
	}()

	serialNumber, err := certutil.GenerateSerialNumber()
	if err != nil {
		return nil, err
	}

	now := time.Now()

	sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner)
	if !ok {
		return nil, errors.New("failed to generate signed SSH key: signer is not an AlgorithmSigner")
	}

	// Prepare the certificate for signing.
	nonce := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce: %w", err)
	}
	certificate := &ssh.Certificate{
		Serial:          serialNumber.Uint64(),
		Key:             b.PublicKey,
		KeyId:           b.KeyID,
		ValidPrincipals: b.ValidPrincipals,
		ValidAfter:      uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()),
		ValidBefore:     uint64(now.Add(b.TTL).In(time.UTC).Unix()),
		CertType:        b.CertificateType,
		Permissions: ssh.Permissions{
			CriticalOptions: b.CriticalOptions,
			Extensions:      b.Extensions,
		},
		Nonce:        nonce,
		SignatureKey: sshAlgorithmSigner.PublicKey(),
	}

	// Get the bytes to sign; this is based on Certificate.bytesForSigning()
	// from the go ssh lib.
	out := certificate.Marshal()
	// Drop trailing signature length.
	certificateBytes := out[:len(out)-4]

	// Resolve the default algorithm selection: RSA CA keys get RSA-SHA2-256,
	// anything else passes an empty algorithm name to the signer.
	algo := b.Role.AlgorithmSigner
	if algo == DefaultAlgorithmSigner {
		if sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA {
			algo = ssh.SigAlgoRSASHA2256
		} else {
			algo = ""
		}
	}

	sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo)
	if err != nil {
		return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err)
	}

	certificate.Signature = sig
	return certificate, nil
}
func createKeyTypeToMapKey ( keyType string , keyBits int ) map [ string ] [ ] string {
keyTypeToMapKey := map [ string ] [ ] string {
"rsa" : { "rsa" , ssh . KeyAlgoRSA } ,
"dsa" : { "dsa" , ssh . KeyAlgoDSA } ,
"ecdsa" : { "ecdsa" , "ec" } ,
"ed25519" : { "ed25519" , ssh . KeyAlgoED25519 } ,
}
if keyType == "ecdsa" {
if algo , ok := ecCurveBitsToAlgoName [ keyBits ] ; ok {
keyTypeToMapKey [ keyType ] = append ( keyTypeToMapKey [ keyType ] , algo )
}
}
return keyTypeToMapKey
}