SSH secrets engine - Enabled creation of key pairs (CA Mode) (#15561)

* Handle func

* Update - check if key_type and key_bits are allowed

* Update - fields

* Generating keys based on provided key_type and key_bits

* Returning signed key

* Refactor

* Refactor update to common logic function

* Descriptions

* Tests added

* Suggested changes and tests added and refactored

* Suggested changes and fmt run

* File refactoring

* Changelog file

* Update changelog/15561.txt

Co-authored-by: Alexander Scheel <alexander.m.scheel@gmail.com>

* Suggested changes - consistent returns and additional info to test messages

* ssh issue key pair documentation

Co-authored-by: Alexander Scheel <alexander.m.scheel@gmail.com>
This commit is contained in:
Gabriel Santos 2022-06-10 14:48:19 +01:00 committed by GitHub
parent 17eed2a814
commit 57eeb33faa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 894 additions and 519 deletions

View File

@ -61,6 +61,7 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathVerify(&b), pathVerify(&b),
pathConfigCA(&b), pathConfigCA(&b),
pathSign(&b), pathSign(&b),
pathIssue(&b),
pathFetchPublicKey(&b), pathFetchPublicKey(&b),
}, },

View File

@ -1776,6 +1776,50 @@ func TestSSHBackend_ValidateNotBeforeDuration(t *testing.T) {
logicaltest.Test(t, testCase) logicaltest.Test(t, testCase)
} }
func TestSSHBackend_IssueSign(t *testing.T) {
config := logical.TestBackendConfig()
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatalf("Cannot create backend: %s", err)
}
testCase := logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
configCaStep(testCAPublicKey, testCAPrivateKey),
createRoleStep("testing", map[string]interface{}{
"key_type": "otp",
"default_user": "user",
}),
// Key pair not issued with invalid role key type
issueSSHKeyPairStep("testing", "rsa", 0, true, "role key type 'otp' not allowed to issue key pairs"),
createRoleStep("testing", map[string]interface{}{
"key_type": "ca",
"allow_user_key_ids": false,
"allow_user_certificates": true,
"allowed_user_key_lengths": map[string]interface{}{
"ssh-rsa": []int{2048, 3072, 4096},
"ecdsa-sha2-nistp521": 0,
"ed25519": 0,
},
}),
// Key_type not in allowed_user_key_types_lengths
issueSSHKeyPairStep("testing", "ec", 256, true, "provided key_type value not in allowed_user_key_types"),
// Key_bits not in allowed_user_key_types_lengths for provided key_type
issueSSHKeyPairStep("testing", "rsa", 2560, true, "provided key_bits value not in list of role's allowed_user_key_types"),
// key_type `rsa` and key_bits `2048` successfully created
issueSSHKeyPairStep("testing", "rsa", 2048, false, ""),
// key_type `ed22519` and key_bits `0` successfully created
issueSSHKeyPairStep("testing", "ed25519", 0, false, ""),
},
}
logicaltest.Test(t, testCase)
}
func getSshCaTestCluster(t *testing.T, userIdentity string) (*vault.TestCluster, string) { func getSshCaTestCluster(t *testing.T, userIdentity string) (*vault.TestCluster, string) {
coreConfig := &vault.CoreConfig{ coreConfig := &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{ CredentialBackends: map[string]logical.Factory{
@ -1847,7 +1891,8 @@ func getSshCaTestCluster(t *testing.T, userIdentity string) (*vault.TestCluster,
} }
func testAllowedUsersTemplate(t *testing.T, testAllowedUsersTemplate string, func testAllowedUsersTemplate(t *testing.T, testAllowedUsersTemplate string,
expectedValidPrincipal string, testEntityMetadata map[string]string) { expectedValidPrincipal string, testEntityMetadata map[string]string,
) {
cluster, userpassToken := getSshCaTestCluster(t, testUserName) cluster, userpassToken := getSshCaTestCluster(t, testUserName)
defer cluster.Cleanup() defer cluster.Cleanup()
client := cluster.Cores[0].Client client := cluster.Cores[0].Client
@ -1926,7 +1971,8 @@ func signCertificateStep(
role, keyID string, certType int, validPrincipals []string, role, keyID string, certType int, validPrincipals []string,
criticalOptionPermissions, extensionPermissions map[string]string, criticalOptionPermissions, extensionPermissions map[string]string,
ttl time.Duration, ttl time.Duration,
requestParameters map[string]interface{}) logicaltest.TestStep { requestParameters map[string]interface{},
) logicaltest.TestStep {
return logicaltest.TestStep{ return logicaltest.TestStep{
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
Path: "sign/" + role, Path: "sign/" + role,
@ -1955,6 +2001,42 @@ func signCertificateStep(
} }
} }
func issueSSHKeyPairStep(role, keyType string, keyBits int, expectError bool, errorMsg string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "issue/" + role,
Data: map[string]interface{}{
"key_type": keyType,
"key_bits": keyBits,
},
ErrorOk: true,
Check: func(resp *logical.Response) error {
if expectError {
var err error
if resp.Data["error"] != errorMsg {
err = fmt.Errorf("actual error message \"%s\" different from expected error message \"%s\"", resp.Data["error"], errorMsg)
}
return err
}
if resp.IsError() {
return fmt.Errorf("unexpected error response returned: %v", resp.Error())
}
if resp.Data["private_key_type"] != keyType {
return fmt.Errorf("response private_key_type (%s) does not match the provided key_type (%s)", resp.Data["private_key_type"], keyType)
}
if resp.Data["signed_key"] == "" {
return errors.New("certificate/signed_key should not be empty")
}
return nil
},
}
}
func validateSSHCertificate(cert *ssh.Certificate, keyID string, certType int, validPrincipals []string, criticalOptionPermissions, extensionPermissions map[string]string, func validateSSHCertificate(cert *ssh.Certificate, keyID string, certType int, validPrincipals []string, criticalOptionPermissions, extensionPermissions map[string]string,
ttl time.Duration, ttl time.Duration,
) error { ) error {

View File

@ -0,0 +1,180 @@
package ssh
import (
"context"
"crypto/rand"
"errors"
"fmt"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
type keySpecs struct {
Type string
Bits int
}
func pathIssue(b *backend) *framework.Path {
return &framework.Path{
Pattern: "issue/" + framework.GenericNameWithAtRegex("role"),
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathIssue,
},
},
Fields: map[string]*framework.FieldSchema{
"role": {
Type: framework.TypeString,
Description: `The desired role with configuration for this request.`,
},
"key_type": {
Type: framework.TypeString,
Description: "Specifies the desired key type; must be `rsa`, `ed25519` or `ec`",
Default: "rsa",
},
"key_bits": {
Type: framework.TypeInt,
Description: "Specifies the number of bits to use for the generated keys.",
Default: 0,
},
"ttl": {
Type: framework.TypeDurationSecond,
Description: `The requested Time To Live for the SSH certificate;
sets the expiration date. If not specified
the role default, backend default, or system
default TTL is used, in that order. Cannot
be later than the role max TTL.`,
},
"valid_principals": {
Type: framework.TypeString,
Description: `Valid principals, either usernames or hostnames, that the certificate should be signed for.`,
},
"cert_type": {
Type: framework.TypeString,
Description: `Type of certificate to be created; either "user" or "host".`,
Default: "user",
},
"key_id": {
Type: framework.TypeString,
Description: `Key id that the created certificate should have. If not specified, the display name of the token will be used.`,
},
"critical_options": {
Type: framework.TypeMap,
Description: `Critical options that the certificate should be signed for.`,
},
"extensions": {
Type: framework.TypeMap,
Description: `Extensions that the certificate should be signed for.`,
},
},
HelpSynopsis: pathIssueHelpSyn,
HelpDescription: pathIssueHelpDesc,
}
}
func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// Get the role
roleName := data.Get("role").(string)
role, err := b.getRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", roleName)), nil
}
if role.KeyType != "ca" {
return logical.ErrorResponse("role key type '%s' not allowed to issue key pairs", role.KeyType), nil
}
// Validate and extract key specifications
keySpecs, err := extractKeySpecs(role, data)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Issue certificate
return b.pathIssueCertificate(ctx, req, data, role, keySpecs)
}
func (b *backend) pathIssueCertificate(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, keySpecs *keySpecs) (*logical.Response, error) {
publicKey, privateKey, err := generateSSHKeyPair(rand.Reader, keySpecs.Type, keySpecs.Bits)
if err != nil {
return nil, err
}
// Sign key
userPublicKey, err := parsePublicSSHKey(publicKey)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to parse public_key as SSH key: %s", err)), nil
}
response, err := b.pathSignIssueCertificateHelper(ctx, req, data, role, userPublicKey)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Additional to sign response
response.Data["private_key"] = privateKey
response.Data["private_key_type"] = keySpecs.Type
return response, nil
}
func extractKeySpecs(role *sshRole, data *framework.FieldData) (*keySpecs, error) {
keyType := data.Get("key_type").(string)
keyBits := data.Get("key_bits").(int)
keySpecs := keySpecs{
Type: keyType,
Bits: keyBits,
}
keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits)
if len(role.AllowedUserKeyTypesLengths) != 0 {
var keyAllowed bool
var bitsAllowed bool
keyTypeAliasesLoop:
for _, keyTypeAlias := range keyTypeToMapKey[keyType] {
allowedValues, allowed := role.AllowedUserKeyTypesLengths[keyTypeAlias]
if !allowed {
continue
}
keyAllowed = true
for _, value := range allowedValues {
if value == keyBits {
bitsAllowed = true
break keyTypeAliasesLoop
}
}
}
if !keyAllowed {
return nil, errors.New("provided key_type value not in allowed_user_key_types")
}
if !bitsAllowed {
return nil, errors.New("provided key_bits value not in list of role's allowed_user_key_types")
}
}
return &keySpecs, nil
}
const pathIssueHelpSyn = `
Request a certificate using a certain role with the provided details.
`
const pathIssueHelpDesc = `
This path allows requesting a certificate to be issued according to the
policy of the given role. The certificate will only be issued if the
requested details are allowed by the role policy.
This path returns a certificate and a private key. If you want a workflow
that does not expose a private key, generate a CSR locally and use the
sign path instead.
`

View File

@ -0,0 +1,539 @@
package ssh
import (
"context"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"errors"
"fmt"
"io"
"regexp"
"strconv"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/crypto/ssh"
)
var ecCurveBitsToAlgoName = map[int]string{
256: ssh.KeyAlgoECDSA256,
384: ssh.KeyAlgoECDSA384,
521: ssh.KeyAlgoECDSA521,
}
// If the algorithm is not found, it could be that we have a curve
// that we haven't added a constant for yet. But they could allow it
// (assuming x/crypto/ssh can parse it) via setting a ec: <keyBits>
// mapping rather than using a named SSH key type, so erring out here
// isn't advisable.
type creationBundle struct {
KeyID string
ValidPrincipals []string
PublicKey ssh.PublicKey
CertificateType uint32
TTL time.Duration
Signer ssh.Signer
Role *sshRole
CriticalOptions map[string]string
Extensions map[string]string
}
func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, publicKey ssh.PublicKey) (*logical.Response, error) {
// Note that these various functions always return "user errors" so we pass
// them as 4xx values
keyID, err := b.calculateKeyID(data, req, role, publicKey)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
certificateType, err := b.calculateCertificateType(data, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
var parsedPrincipals []string
if certificateType == ssh.HostCert {
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, validateValidPrincipalForHosts(role))
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
} else {
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, role.DefaultUser, role.AllowedUsers, strutil.StrListContains)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}
ttl, err := b.calculateTTL(data, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
criticalOptions, err := b.calculateCriticalOptions(data, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
extensions, err := b.calculateExtensions(data, req, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
privateKeyEntry, err := caKey(ctx, req.Storage, caPrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to read CA private key: %w", err)
}
if privateKeyEntry == nil || privateKeyEntry.Key == "" {
return nil, errors.New("failed to read CA private key")
}
signer, err := ssh.ParsePrivateKey([]byte(privateKeyEntry.Key))
if err != nil {
return nil, fmt.Errorf("failed to parse stored CA private key: %w", err)
}
cBundle := creationBundle{
KeyID: keyID,
PublicKey: publicKey,
Signer: signer,
ValidPrincipals: parsedPrincipals,
TTL: ttl,
CertificateType: certificateType,
Role: role,
CriticalOptions: criticalOptions,
Extensions: extensions,
}
certificate, err := cBundle.sign()
if err != nil {
return nil, err
}
signedSSHCertificate := ssh.MarshalAuthorizedKey(certificate)
if len(signedSSHCertificate) == 0 {
return nil, errors.New("error marshaling signed certificate")
}
response := &logical.Response{
Data: map[string]interface{}{
"serial_number": strconv.FormatUint(certificate.Serial, 16),
"signed_key": string(signedSSHCertificate),
},
}
return response, nil
}
func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, validatePrincipal func([]string, string) bool) ([]string, error) {
validPrincipals := ""
validPrincipalsRaw, ok := data.GetOk("valid_principals")
if ok {
validPrincipals = validPrincipalsRaw.(string)
} else {
validPrincipals = defaultPrincipal
}
parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false)
// Build list of allowed Principals from template and static principalsAllowedByRole
var allowedPrincipals []string
for _, principal := range strutil.RemoveDuplicates(strutil.ParseStringSlice(principalsAllowedByRole, ","), false) {
if role.AllowedUsersTemplate {
// Look for templating markers {{ .* }}
matched, _ := regexp.MatchString(`{{.+?}}`, principal)
if matched {
if req.EntityID != "" {
// Retrieve principal based on template + entityID from request.
templatePrincipal, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System())
if err == nil {
// Template returned a principal
allowedPrincipals = append(allowedPrincipals, templatePrincipal)
} else {
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", principal, err)
}
}
} else {
// Static principal or err template
allowedPrincipals = append(allowedPrincipals, principal)
}
} else {
// Static principal
allowedPrincipals = append(allowedPrincipals, principal)
}
}
switch {
case len(parsedPrincipals) == 0:
// There is nothing to process
return nil, nil
case len(allowedPrincipals) == 0:
// User has requested principals to be set, but role is not configured
// with any principals
return nil, fmt.Errorf("role is not configured to allow any principals")
default:
// Role was explicitly configured to allow any principal.
if principalsAllowedByRole == "*" {
return parsedPrincipals, nil
}
for _, principal := range parsedPrincipals {
if !validatePrincipal(strutil.RemoveDuplicates(allowedPrincipals, false), principal) {
return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal)
}
}
return parsedPrincipals, nil
}
}
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
return func(allowedPrincipals []string, validPrincipal string) bool {
for _, allowedPrincipal := range allowedPrincipals {
if allowedPrincipal == validPrincipal && role.AllowBareDomains {
return true
}
if role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+allowedPrincipal) {
return true
}
}
return false
}
}
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
requestedCertificateType := data.Get("cert_type").(string)
var certificateType uint32
switch requestedCertificateType {
case "user":
if !role.AllowUserCertificates {
return 0, errors.New("cert_type 'user' is not allowed by role")
}
certificateType = ssh.UserCert
case "host":
if !role.AllowHostCertificates {
return 0, errors.New("cert_type 'host' is not allowed by role")
}
certificateType = ssh.HostCert
default:
return 0, errors.New("cert_type must be either 'user' or 'host'")
}
return certificateType, nil
}
func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) {
reqID := data.Get("key_id").(string)
if reqID != "" {
if !role.AllowUserKeyIDs {
return "", fmt.Errorf("setting key_id is not allowed by role")
}
return reqID, nil
}
keyIDFormat := "vault-{{token_display_name}}-{{public_key_hash}}"
if req.DisplayName == "" {
keyIDFormat = "vault-{{public_key_hash}}"
}
if role.KeyIDFormat != "" {
keyIDFormat = role.KeyIDFormat
}
keyID := substQuery(keyIDFormat, map[string]string{
"token_display_name": req.DisplayName,
"role_name": data.Get("role").(string),
"public_key_hash": fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())),
})
return keyID, nil
}
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
unparsedCriticalOptions := data.Get("critical_options").(map[string]interface{})
if len(unparsedCriticalOptions) == 0 {
return role.DefaultCriticalOptions, nil
}
criticalOptions := convertMapToStringValue(unparsedCriticalOptions)
if role.AllowedCriticalOptions != "" {
notAllowedOptions := []string{}
allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",")
for option := range criticalOptions {
if !strutil.StrListContains(allowedCriticalOptions, option) {
notAllowedOptions = append(notAllowedOptions, option)
}
}
if len(notAllowedOptions) != 0 {
return nil, fmt.Errorf("critical options not on allowed list: %v", notAllowedOptions)
}
}
return criticalOptions, nil
}
func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, error) {
unparsedExtensions := data.Get("extensions").(map[string]interface{})
extensions := make(map[string]string)
if len(unparsedExtensions) > 0 {
extensions := convertMapToStringValue(unparsedExtensions)
if role.AllowedExtensions == "*" {
// Allowed extensions was configured to allow all
return extensions, nil
}
notAllowed := []string{}
allowedExtensions := strings.Split(role.AllowedExtensions, ",")
for extensionKey := range extensions {
if !strutil.StrListContains(allowedExtensions, extensionKey) {
notAllowed = append(notAllowed, extensionKey)
}
}
if len(notAllowed) != 0 {
return nil, fmt.Errorf("extensions %v are not on allowed list", notAllowed)
}
return extensions, nil
}
if role.DefaultExtensionsTemplate {
for extensionKey, extensionValue := range role.DefaultExtensions {
// Look for templating markers {{ .* }}
matched, _ := regexp.MatchString(`^{{.+?}}$`, extensionValue)
if matched {
if req.EntityID != "" {
// Retrieve extension value based on template + entityID from request.
templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System())
if err == nil {
// Template returned an extension value that we can use
extensions[extensionKey] = templateExtensionValue
} else {
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err)
}
}
} else {
// Static extension value or err template
extensions[extensionKey] = extensionValue
}
}
} else {
extensions = role.DefaultExtensions
}
return extensions, nil
}
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
var ttl, maxTTL time.Duration
var err error
ttlRaw, specifiedTTL := data.GetOk("ttl")
if specifiedTTL {
ttl = time.Duration(ttlRaw.(int)) * time.Second
} else {
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return 0, err
}
}
if ttl == 0 {
ttl = b.System().DefaultLeaseTTL()
}
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, err
}
if maxTTL == 0 {
maxTTL = b.System().MaxLeaseTTL()
}
if ttl > maxTTL {
// Don't error if they were using system defaults, only error if
// they specifically chose a bad TTL
if !specifiedTTL {
ttl = maxTTL
} else {
return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second)
}
}
return ttl, nil
}
func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error {
if len(role.AllowedUserKeyTypesLengths) != 0 {
var keyType string
var keyBits int
switch k := publickey.(type) {
case ssh.CryptoPublicKey:
ff := k.CryptoPublicKey()
switch k := ff.(type) {
case *rsa.PublicKey:
keyType = "rsa"
keyBits = k.N.BitLen()
case *dsa.PublicKey:
keyType = "dsa"
keyBits = k.Parameters.P.BitLen()
case *ecdsa.PublicKey:
keyType = "ecdsa"
keyBits = k.Curve.Params().BitSize
case ed25519.PublicKey:
keyType = "ed25519"
default:
return fmt.Errorf("public key type of %s is not allowed", keyType)
}
default:
return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", k)
}
keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits)
var present bool
var pass bool
for _, kstr := range keyTypeToMapKey[keyType] {
allowed_values, ok := role.AllowedUserKeyTypesLengths[kstr]
if !ok {
continue
}
present = true
for _, value := range allowed_values {
if keyType == "rsa" || keyType == "dsa" {
// Regardless of map naming, we always need to validate the
// bit length of RSA and DSA keys. Use the keyType flag to
if keyBits == value {
pass = true
}
} else if kstr == "ec" || kstr == "ecdsa" {
// If the map string is "ecdsa", we have to validate the keyBits
// are a match for an allowed value, meaning that our curve
// is allowed. This isn't necessary when a named curve (e.g.
// ssh.KeyAlgoECDSA256) is allowed (and hence kstr is that),
// because keyBits is already specified in the kstr. Thus,
// we have conditioned around kstr and not keyType (like with
// rsa or dsa).
if keyBits == value {
pass = true
}
} else {
// We get here in two cases: we have a algo-named EC key
// matching a format specifier in the key map (e.g., a P-256
// key with a KeyAlgoECDSA256 entry in the map) or we have a
// ed25519 key (which is always allowed).
pass = true
}
}
}
if !present {
return fmt.Errorf("key of type %s is not allowed", keyType)
}
if !pass {
return fmt.Errorf("key is of an invalid size: %v", keyBits)
}
}
return nil
}
func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) {
defer func() {
if r := recover(); r != nil {
errMsg, ok := r.(string)
if ok {
retCert = nil
retErr = errors.New(errMsg)
}
}
}()
serialNumber, err := certutil.GenerateSerialNumber()
if err != nil {
return nil, err
}
now := time.Now()
sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner)
if !ok {
return nil, fmt.Errorf("failed to generate signed SSH key: signer is not an AlgorithmSigner")
}
// prepare certificate for signing
nonce := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce")
}
certificate := &ssh.Certificate{
Serial: serialNumber.Uint64(),
Key: b.PublicKey,
KeyId: b.KeyID,
ValidPrincipals: b.ValidPrincipals,
ValidAfter: uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()),
ValidBefore: uint64(now.Add(b.TTL).In(time.UTC).Unix()),
CertType: b.CertificateType,
Permissions: ssh.Permissions{
CriticalOptions: b.CriticalOptions,
Extensions: b.Extensions,
},
Nonce: nonce,
SignatureKey: sshAlgorithmSigner.PublicKey(),
}
// get bytes to sign; this is based on Certificate.bytesForSigning() from the go ssh lib
out := certificate.Marshal()
// Drop trailing signature length.
certificateBytes := out[:len(out)-4]
algo := b.Role.AlgorithmSigner
// Handle the new default algorithm selection process correctly.
if algo == DefaultAlgorithmSigner && sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA {
algo = ssh.SigAlgoRSASHA2256
} else if algo == DefaultAlgorithmSigner {
algo = ""
}
sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo)
if err != nil {
return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err)
}
certificate.Signature = sig
return certificate, nil
}
func createKeyTypeToMapKey(keyType string, keyBits int) map[string][]string {
keyTypeToMapKey := map[string][]string{
"rsa": {"rsa", ssh.KeyAlgoRSA},
"dsa": {"dsa", ssh.KeyAlgoDSA},
"ecdsa": {"ecdsa", "ec"},
"ed25519": {"ed25519", ssh.KeyAlgoED25519},
}
if keyType == "ecdsa" {
if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok {
keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo)
}
}
return keyTypeToMapKey
}

View File

@ -2,40 +2,12 @@ package ssh
import ( import (
"context" "context"
"crypto/dsa"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"errors"
"fmt" "fmt"
"io"
"regexp"
"strconv"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
) )
type creationBundle struct {
KeyID string
ValidPrincipals []string
PublicKey ssh.PublicKey
CertificateType uint32
TTL time.Duration
Signer ssh.Signer
Role *sshRole
CriticalOptions map[string]string
Extensions map[string]string
}
func pathSign(b *backend) *framework.Path { func pathSign(b *backend) *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: "sign/" + framework.GenericNameWithAtRegex("role"), Pattern: "sign/" + framework.GenericNameWithAtRegex("role"),
@ -120,497 +92,10 @@ func (b *backend) pathSignCertificate(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(fmt.Sprintf("public_key failed to meet the key requirements: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("public_key failed to meet the key requirements: %s", err)), nil
} }
// Note that these various functions always return "user errors" so we pass response, err := b.pathSignIssueCertificateHelper(ctx, req, data, role, userPublicKey)
// them as 4xx values
keyID, err := b.calculateKeyID(data, req, role, userPublicKey)
if err != nil { if err != nil {
return logical.ErrorResponse(err.Error()), nil return response, err
}
certificateType, err := b.calculateCertificateType(data, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
var parsedPrincipals []string
if certificateType == ssh.HostCert {
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, validateValidPrincipalForHosts(role))
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
} else {
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, role.DefaultUser, role.AllowedUsers, strutil.StrListContains)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}
ttl, err := b.calculateTTL(data, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
criticalOptions, err := b.calculateCriticalOptions(data, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
extensions, err := b.calculateExtensions(data, req, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
privateKeyEntry, err := caKey(ctx, req.Storage, caPrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to read CA private key: %w", err)
}
if privateKeyEntry == nil || privateKeyEntry.Key == "" {
return nil, fmt.Errorf("failed to read CA private key")
}
signer, err := ssh.ParsePrivateKey([]byte(privateKeyEntry.Key))
if err != nil {
return nil, fmt.Errorf("failed to parse stored CA private key: %w", err)
}
cBundle := creationBundle{
KeyID: keyID,
PublicKey: userPublicKey,
Signer: signer,
ValidPrincipals: parsedPrincipals,
TTL: ttl,
CertificateType: certificateType,
Role: role,
CriticalOptions: criticalOptions,
Extensions: extensions,
}
certificate, err := cBundle.sign()
if err != nil {
return nil, err
}
signedSSHCertificate := ssh.MarshalAuthorizedKey(certificate)
if len(signedSSHCertificate) == 0 {
return nil, fmt.Errorf("error marshaling signed certificate")
}
response := &logical.Response{
Data: map[string]interface{}{
"serial_number": strconv.FormatUint(certificate.Serial, 16),
"signed_key": string(signedSSHCertificate),
},
} }
return response, nil return response, nil
} }
func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, validatePrincipal func([]string, string) bool) ([]string, error) {
validPrincipals := ""
validPrincipalsRaw, ok := data.GetOk("valid_principals")
if ok {
validPrincipals = validPrincipalsRaw.(string)
} else {
validPrincipals = defaultPrincipal
}
parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false)
// Build list of allowed Principals from template and static principalsAllowedByRole
var allowedPrincipals []string
for _, principal := range strutil.RemoveDuplicates(strutil.ParseStringSlice(principalsAllowedByRole, ","), false) {
if role.AllowedUsersTemplate {
// Look for templating markers {{ .* }}
matched, _ := regexp.MatchString(`{{.+?}}`, principal)
if matched {
if req.EntityID != "" {
// Retrieve principal based on template + entityID from request.
templatePrincipal, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System())
if err == nil {
// Template returned a principal
allowedPrincipals = append(allowedPrincipals, templatePrincipal)
} else {
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", principal, err)
}
}
} else {
// Static principal or err template
allowedPrincipals = append(allowedPrincipals, principal)
}
} else {
// Static principal
allowedPrincipals = append(allowedPrincipals, principal)
}
}
switch {
case len(parsedPrincipals) == 0:
// There is nothing to process
return nil, nil
case len(allowedPrincipals) == 0:
// User has requested principals to be set, but role is not configured
// with any principals
return nil, fmt.Errorf("role is not configured to allow any principals")
default:
// Role was explicitly configured to allow any principal.
if principalsAllowedByRole == "*" {
return parsedPrincipals, nil
}
for _, principal := range parsedPrincipals {
if !validatePrincipal(strutil.RemoveDuplicates(allowedPrincipals, false), principal) {
return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal)
}
}
return parsedPrincipals, nil
}
}
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
return func(allowedPrincipals []string, validPrincipal string) bool {
for _, allowedPrincipal := range allowedPrincipals {
if allowedPrincipal == validPrincipal && role.AllowBareDomains {
return true
}
if role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+allowedPrincipal) {
return true
}
}
return false
}
}
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
requestedCertificateType := data.Get("cert_type").(string)
var certificateType uint32
switch requestedCertificateType {
case "user":
if !role.AllowUserCertificates {
return 0, errors.New("cert_type 'user' is not allowed by role")
}
certificateType = ssh.UserCert
case "host":
if !role.AllowHostCertificates {
return 0, errors.New("cert_type 'host' is not allowed by role")
}
certificateType = ssh.HostCert
default:
return 0, errors.New("cert_type must be either 'user' or 'host'")
}
return certificateType, nil
}
func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) {
reqID := data.Get("key_id").(string)
if reqID != "" {
if !role.AllowUserKeyIDs {
return "", fmt.Errorf("setting key_id is not allowed by role")
}
return reqID, nil
}
keyIDFormat := "vault-{{token_display_name}}-{{public_key_hash}}"
if req.DisplayName == "" {
keyIDFormat = "vault-{{public_key_hash}}"
}
if role.KeyIDFormat != "" {
keyIDFormat = role.KeyIDFormat
}
keyID := substQuery(keyIDFormat, map[string]string{
"token_display_name": req.DisplayName,
"role_name": data.Get("role").(string),
"public_key_hash": fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())),
})
return keyID, nil
}
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
unparsedCriticalOptions := data.Get("critical_options").(map[string]interface{})
if len(unparsedCriticalOptions) == 0 {
return role.DefaultCriticalOptions, nil
}
criticalOptions := convertMapToStringValue(unparsedCriticalOptions)
if role.AllowedCriticalOptions != "" {
notAllowedOptions := []string{}
allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",")
for option := range criticalOptions {
if !strutil.StrListContains(allowedCriticalOptions, option) {
notAllowedOptions = append(notAllowedOptions, option)
}
}
if len(notAllowedOptions) != 0 {
return nil, fmt.Errorf("critical options not on allowed list: %v", notAllowedOptions)
}
}
return criticalOptions, nil
}
func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, error) {
unparsedExtensions := data.Get("extensions").(map[string]interface{})
extensions := make(map[string]string)
if len(unparsedExtensions) > 0 {
extensions := convertMapToStringValue(unparsedExtensions)
if role.AllowedExtensions == "*" {
// Allowed extensions was configured to allow all
return extensions, nil
}
notAllowed := []string{}
allowedExtensions := strings.Split(role.AllowedExtensions, ",")
for extensionKey := range extensions {
if !strutil.StrListContains(allowedExtensions, extensionKey) {
notAllowed = append(notAllowed, extensionKey)
}
}
if len(notAllowed) != 0 {
return nil, fmt.Errorf("extensions %v are not on allowed list", notAllowed)
}
return extensions, nil
}
if role.DefaultExtensionsTemplate {
for extensionKey, extensionValue := range role.DefaultExtensions {
// Look for templating markers {{ .* }}
matched, _ := regexp.MatchString(`^{{.+?}}$`, extensionValue)
if matched {
if req.EntityID != "" {
// Retrieve extension value based on template + entityID from request.
templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System())
if err == nil {
// Template returned an extension value that we can use
extensions[extensionKey] = templateExtensionValue
} else {
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err)
}
}
} else {
// Static extension value or err template
extensions[extensionKey] = extensionValue
}
}
} else {
extensions = role.DefaultExtensions
}
return extensions, nil
}
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
var ttl, maxTTL time.Duration
var err error
ttlRaw, specifiedTTL := data.GetOk("ttl")
if specifiedTTL {
ttl = time.Duration(ttlRaw.(int)) * time.Second
} else {
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return 0, err
}
}
if ttl == 0 {
ttl = b.System().DefaultLeaseTTL()
}
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, err
}
if maxTTL == 0 {
maxTTL = b.System().MaxLeaseTTL()
}
if ttl > maxTTL {
// Don't error if they were using system defaults, only error if
// they specifically chose a bad TTL
if !specifiedTTL {
ttl = maxTTL
} else {
return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second)
}
}
return ttl, nil
}
func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error {
if len(role.AllowedUserKeyTypesLengths) != 0 {
var keyType string
var keyBits int
switch k := publickey.(type) {
case ssh.CryptoPublicKey:
ff := k.CryptoPublicKey()
switch k := ff.(type) {
case *rsa.PublicKey:
keyType = "rsa"
keyBits = k.N.BitLen()
case *dsa.PublicKey:
keyType = "dsa"
keyBits = k.Parameters.P.BitLen()
case *ecdsa.PublicKey:
keyType = "ecdsa"
keyBits = k.Curve.Params().BitSize
case ed25519.PublicKey:
keyType = "ed25519"
default:
return fmt.Errorf("public key type of %s is not allowed", keyType)
}
default:
return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", k)
}
keyTypeToMapKey := map[string][]string{
"rsa": {"rsa", ssh.KeyAlgoRSA},
"dsa": {"dsa", ssh.KeyAlgoDSA},
"ecdsa": {"ecdsa", "ec"},
"ed25519": {"ed25519", ssh.KeyAlgoED25519},
}
if keyType == "ecdsa" {
ecCurveBitsToAlgoName := map[int]string{
256: ssh.KeyAlgoECDSA256,
384: ssh.KeyAlgoECDSA384,
521: ssh.KeyAlgoECDSA521,
}
if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok {
keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo)
}
// If the algorithm is not found, it could be that we have a curve
// that we haven't added a constant for yet. But they could allow it
// (assuming x/crypto/ssh can parse it) via setting a ec: <keyBits>
// mapping rather than using a named SSH key type, so erring out here
// isn't advisable.
}
var present bool
var pass bool
for _, kstr := range keyTypeToMapKey[keyType] {
allowed_values, ok := role.AllowedUserKeyTypesLengths[kstr]
if !ok {
continue
}
present = true
for _, value := range allowed_values {
if keyType == "rsa" || keyType == "dsa" {
// Regardless of map naming, we always need to validate the
// bit length of RSA and DSA keys. Use the keyType flag to
if keyBits == value {
pass = true
}
} else if kstr == "ec" || kstr == "ecdsa" {
// If the map string is "ecdsa", we have to validate the keyBits
// are a match for an allowed value, meaning that our curve
// is allowed. This isn't necessary when a named curve (e.g.
// ssh.KeyAlgoECDSA256) is allowed (and hence kstr is that),
// because keyBits is already specified in the kstr. Thus,
// we have conditioned around kstr and not keyType (like with
// rsa or dsa).
if keyBits == value {
pass = true
}
} else {
// We get here in two cases: we have a algo-named EC key
// matching a format specifier in the key map (e.g., a P-256
// key with a KeyAlgoECDSA256 entry in the map) or we have a
// ed25519 key (which is always allowed).
pass = true
}
}
}
if !present {
return fmt.Errorf("key of type %s is not allowed", keyType)
}
if !pass {
return fmt.Errorf("key is of an invalid size: %v", keyBits)
}
}
return nil
}
func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) {
defer func() {
if r := recover(); r != nil {
errMsg, ok := r.(string)
if ok {
retCert = nil
retErr = errors.New(errMsg)
}
}
}()
serialNumber, err := certutil.GenerateSerialNumber()
if err != nil {
return nil, err
}
now := time.Now()
sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner)
if !ok {
return nil, fmt.Errorf("failed to generate signed SSH key: signer is not an AlgorithmSigner")
}
// prepare certificate for signing
nonce := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce")
}
certificate := &ssh.Certificate{
Serial: serialNumber.Uint64(),
Key: b.PublicKey,
KeyId: b.KeyID,
ValidPrincipals: b.ValidPrincipals,
ValidAfter: uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()),
ValidBefore: uint64(now.Add(b.TTL).In(time.UTC).Unix()),
CertType: b.CertificateType,
Permissions: ssh.Permissions{
CriticalOptions: b.CriticalOptions,
Extensions: b.Extensions,
},
Nonce: nonce,
SignatureKey: sshAlgorithmSigner.PublicKey(),
}
// get bytes to sign; this is based on Certificate.bytesForSigning() from the go ssh lib
out := certificate.Marshal()
// Drop trailing signature length.
certificateBytes := out[:len(out)-4]
algo := b.Role.AlgorithmSigner
// Handle the new default algorithm selection process correctly.
if algo == DefaultAlgorithmSigner && sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA {
algo = ssh.SigAlgoRSASHA2256
} else if algo == DefaultAlgorithmSigner {
algo = ""
}
sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo)
if err != nil {
return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err)
}
certificate.Signature = sig
return certificate, nil
}

3
changelog/15561.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
ssh: Addition of an endpoint `ssh/issue/:role` to allow the creation of signed key pairs
```

View File

@ -873,3 +873,88 @@ $ curl \
"auth": null "auth": null
} }
``` ```
## Generate Certificate and Key
This endpoint generates a new set of credentials (private key and certificate)
based on the role named in the endpoint.
~> **Note**: The private key is _not_ stored. If you do not save the private
key from the response, you will need to request a new certificate.
| Method | Path |
| :----- | :---------------- |
| `POST` | `/ssh/issue/:name` |
### Parameters
- `name` `(string: <required>)` Specifies the name of the role to create the
certificate against. This is part of the request URL.
- `key_type` `(string: "rsa")` Specifies the desired key type; must be `rsa`, `ed25519`
or `ec`.
- `key_bits` `(int: 0)` Specifies the number of bits to use for the
generated keys. Allowed values are 0 (universal default); with
`key_type=rsa`, allowed values are: 2048 (default), 3072, or
4096; with `key_type=ec`, allowed values are: 224, 256 (default),
384, or 521; ignored with `key_type=ed25519`.
- `ttl` `(string: "")` Specifies the Requested Time To Live. Cannot be greater
than the role's `max_ttl` value. If not provided, the role's `ttl` value will
be used. Note that the role values default to system values if not explicitly
set.
- `valid_principals` `(string: "")` Specifies valid principals, either
usernames or hostnames, that the certificate should be signed for.
- `cert_type` `(string: "user")` Specifies the type of certificate to be
created; either "user" or "host".
- `key_id` `(string: "")` Specifies the key id that the created certificate
should have. If not specified, the display name of the token will be used.
- `critical_options` `(map<string|string>: "")` Specifies a map of the
critical options that the certificate should be signed for. Defaults to none.
- `extensions` `(map<string|string>: "")` Specifies a map of the extensions
that the certificate should be signed for. Defaults to none.
### Sample Payload
```json
{
"key_type": "rsa",
"key_bits": 2048
}
```
### Sample Request
```shell-session
$ curl \
--header "X-Vault-Token: ..." \
--request POST \
--data @payload.json \
http://127.0.0.1:8200/v1/ssh/issue/my-role
```
### Sample Response
```json
{
"request_id": "94fd1102-08a1-c207-0e3e-657e8f80c09e",
"lease_id": "",
"renewable": false,
"lease_duration": 0,
"data": {
"serial_number": "1e965817eb12a511",
"signed_key": "ssh-rsa-cert-v01@openssh.com AAAAHHN...\n",
"private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIEpQIBAAKCAQEAwer03vkQrPV+wWpbisJJv2CKqHmMz+Ej0ctLbhpOmR2CY9S9\n...\nQN351pgTphi6nlCkGPzkDuwvtxSxiCWXQcaxrHAL7MiJpPzkIBq1\n-----END RSA PRIVATE KEY-----\n",
"private_key_type": "rsa"
},
"wrap_info": null,
"warnings": null,
"auth": null
}
```