diff --git a/builtin/logical/ssh/backend.go b/builtin/logical/ssh/backend.go index c7250d036..fe4f40b33 100644 --- a/builtin/logical/ssh/backend.go +++ b/builtin/logical/ssh/backend.go @@ -61,6 +61,7 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { pathVerify(&b), pathConfigCA(&b), pathSign(&b), + pathIssue(&b), pathFetchPublicKey(&b), }, diff --git a/builtin/logical/ssh/backend_test.go b/builtin/logical/ssh/backend_test.go index d830103e8..d762174fb 100644 --- a/builtin/logical/ssh/backend_test.go +++ b/builtin/logical/ssh/backend_test.go @@ -1776,6 +1776,50 @@ func TestSSHBackend_ValidateNotBeforeDuration(t *testing.T) { logicaltest.Test(t, testCase) } +func TestSSHBackend_IssueSign(t *testing.T) { + config := logical.TestBackendConfig() + + b, err := Factory(context.Background(), config) + if err != nil { + t.Fatalf("Cannot create backend: %s", err) + } + + testCase := logicaltest.TestCase{ + LogicalBackend: b, + Steps: []logicaltest.TestStep{ + configCaStep(testCAPublicKey, testCAPrivateKey), + + createRoleStep("testing", map[string]interface{}{ + "key_type": "otp", + "default_user": "user", + }), + // Key pair not issued with invalid role key type + issueSSHKeyPairStep("testing", "rsa", 0, true, "role key type 'otp' not allowed to issue key pairs"), + + createRoleStep("testing", map[string]interface{}{ + "key_type": "ca", + "allow_user_key_ids": false, + "allow_user_certificates": true, + "allowed_user_key_lengths": map[string]interface{}{ + "ssh-rsa": []int{2048, 3072, 4096}, + "ecdsa-sha2-nistp521": 0, + "ed25519": 0, + }, + }), + // Key_type not in allowed_user_key_types_lengths + issueSSHKeyPairStep("testing", "ec", 256, true, "provided key_type value not in allowed_user_key_types"), + // Key_bits not in allowed_user_key_types_lengths for provided key_type + issueSSHKeyPairStep("testing", "rsa", 2560, true, "provided key_bits value not in list of role's allowed_user_key_types"), + // key_type `rsa` and key_bits `2048` successfully created + issueSSHKeyPairStep("testing", "rsa", 2048, false, ""), + // key_type `ed22519` and key_bits `0` successfully created + issueSSHKeyPairStep("testing", "ed25519", 0, false, ""), + }, + } + + logicaltest.Test(t, testCase) +} + func getSshCaTestCluster(t *testing.T, userIdentity string) (*vault.TestCluster, string) { coreConfig := &vault.CoreConfig{ CredentialBackends: map[string]logical.Factory{ @@ -1847,7 +1891,8 @@ func getSshCaTestCluster(t *testing.T, userIdentity string) (*vault.TestCluster, } func testAllowedUsersTemplate(t *testing.T, testAllowedUsersTemplate string, - expectedValidPrincipal string, testEntityMetadata map[string]string) { + expectedValidPrincipal string, testEntityMetadata map[string]string, +) { cluster, userpassToken := getSshCaTestCluster(t, testUserName) defer cluster.Cleanup() client := cluster.Cores[0].Client @@ -1926,7 +1971,8 @@ func signCertificateStep( role, keyID string, certType int, validPrincipals []string, criticalOptionPermissions, extensionPermissions map[string]string, ttl time.Duration, - requestParameters map[string]interface{}) logicaltest.TestStep { + requestParameters map[string]interface{}, +) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "sign/" + role, @@ -1955,6 +2001,42 @@ func signCertificateStep( } } +func issueSSHKeyPairStep(role, keyType string, keyBits int, expectError bool, errorMsg string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "issue/" + role, + Data: map[string]interface{}{ + "key_type": keyType, + "key_bits": keyBits, + }, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if expectError { + var err error + if resp.Data["error"] != errorMsg { + err = fmt.Errorf("actual error message \"%s\" different from expected error message \"%s\"", resp.Data["error"], errorMsg) + } + + return err + } + + if resp.IsError() { + return fmt.Errorf("unexpected error response returned: %v", resp.Error()) + } + + if resp.Data["private_key_type"] != keyType { + return fmt.Errorf("response private_key_type (%s) does not match the provided key_type (%s)", resp.Data["private_key_type"], keyType) + } + + if resp.Data["signed_key"] == "" { + return errors.New("certificate/signed_key should not be empty") + } + + return nil + }, + } +} + func validateSSHCertificate(cert *ssh.Certificate, keyID string, certType int, validPrincipals []string, criticalOptionPermissions, extensionPermissions map[string]string, ttl time.Duration, ) error { diff --git a/builtin/logical/ssh/path_issue.go b/builtin/logical/ssh/path_issue.go new file mode 100644 index 000000000..19b57cb93 --- /dev/null +++ b/builtin/logical/ssh/path_issue.go @@ -0,0 +1,180 @@ +package ssh + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" +) + +type keySpecs struct { + Type string + Bits int +} + +func pathIssue(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "issue/" + framework.GenericNameWithAtRegex("role"), + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathIssue, + }, + }, + Fields: map[string]*framework.FieldSchema{ + "role": { + Type: framework.TypeString, + Description: `The desired role with configuration for this request.`, + }, + "key_type": { + Type: framework.TypeString, + Description: "Specifies the desired key type; must be `rsa`, `ed25519` or `ec`", + Default: "rsa", + }, + "key_bits": { + Type: framework.TypeInt, + Description: "Specifies the number of bits to use for the generated keys.", + Default: 0, + }, + "ttl": { + Type: framework.TypeDurationSecond, + Description: `The requested Time To Live for the SSH certificate; +sets the expiration date. If not specified +the role default, backend default, or system +default TTL is used, in that order. Cannot +be later than the role max TTL.`, + }, + "valid_principals": { + Type: framework.TypeString, + Description: `Valid principals, either usernames or hostnames, that the certificate should be signed for.`, + }, + "cert_type": { + Type: framework.TypeString, + Description: `Type of certificate to be created; either "user" or "host".`, + Default: "user", + }, + "key_id": { + Type: framework.TypeString, + Description: `Key id that the created certificate should have. If not specified, the display name of the token will be used.`, + }, + "critical_options": { + Type: framework.TypeMap, + Description: `Critical options that the certificate should be signed for.`, + }, + "extensions": { + Type: framework.TypeMap, + Description: `Extensions that the certificate should be signed for.`, + }, + }, + HelpSynopsis: pathIssueHelpSyn, + HelpDescription: pathIssueHelpDesc, + } +} + +func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + // Get the role + roleName := data.Get("role").(string) + role, err := b.getRole(ctx, req.Storage, roleName) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", roleName)), nil + } + + if role.KeyType != "ca" { + return logical.ErrorResponse("role key type '%s' not allowed to issue key pairs", role.KeyType), nil + } + + // Validate and extract key specifications + keySpecs, err := extractKeySpecs(role, data) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + // Issue certificate + return b.pathIssueCertificate(ctx, req, data, role, keySpecs) +} + +func (b *backend) pathIssueCertificate(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, keySpecs *keySpecs) (*logical.Response, error) { + publicKey, privateKey, err := generateSSHKeyPair(rand.Reader, keySpecs.Type, keySpecs.Bits) + if err != nil { + return nil, err + } + + // Sign key + userPublicKey, err := parsePublicSSHKey(publicKey) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("failed to parse public_key as SSH key: %s", err)), nil + } + + response, err := b.pathSignIssueCertificateHelper(ctx, req, data, role, userPublicKey) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + // Additional to sign response + response.Data["private_key"] = privateKey + response.Data["private_key_type"] = keySpecs.Type + + return response, nil +} + +func extractKeySpecs(role *sshRole, data *framework.FieldData) (*keySpecs, error) { + keyType := data.Get("key_type").(string) + keyBits := data.Get("key_bits").(int) + keySpecs := keySpecs{ + Type: keyType, + Bits: keyBits, + } + + keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits) + + if len(role.AllowedUserKeyTypesLengths) != 0 { + var keyAllowed bool + var bitsAllowed bool + + keyTypeAliasesLoop: + for _, keyTypeAlias := range keyTypeToMapKey[keyType] { + allowedValues, allowed := role.AllowedUserKeyTypesLengths[keyTypeAlias] + if !allowed { + continue + } + keyAllowed = true + + for _, value := range allowedValues { + if value == keyBits { + bitsAllowed = true + break keyTypeAliasesLoop + } + } + } + + if !keyAllowed { + return nil, errors.New("provided key_type value not in allowed_user_key_types") + } + + if !bitsAllowed { + return nil, errors.New("provided key_bits value not in list of role's allowed_user_key_types") + } + } + + return &keySpecs, nil +} + +const pathIssueHelpSyn = ` +Request a certificate using a certain role with the provided details. +` + +const pathIssueHelpDesc = ` +This path allows requesting a certificate to be issued according to the +policy of the given role. The certificate will only be issued if the +requested details are allowed by the role policy. + +This path returns a certificate and a private key. If you want a workflow +that does not expose a private key, generate a CSR locally and use the +sign path instead. +` diff --git a/builtin/logical/ssh/path_issue_sign.go b/builtin/logical/ssh/path_issue_sign.go new file mode 100644 index 000000000..117767d4e --- /dev/null +++ b/builtin/logical/ssh/path_issue_sign.go @@ -0,0 +1,539 @@ +package ssh + +import ( + "context" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "errors" + "fmt" + "io" + "regexp" + "strconv" + "strings" + "time" + + "github.com/hashicorp/go-secure-stdlib/parseutil" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/certutil" + "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/sdk/logical" + "golang.org/x/crypto/ssh" +) + +var ecCurveBitsToAlgoName = map[int]string{ + 256: ssh.KeyAlgoECDSA256, + 384: ssh.KeyAlgoECDSA384, + 521: ssh.KeyAlgoECDSA521, +} + +// If the algorithm is not found, it could be that we have a curve +// that we haven't added a constant for yet. But they could allow it +// (assuming x/crypto/ssh can parse it) via setting a ec: +// mapping rather than using a named SSH key type, so erring out here +// isn't advisable. + +type creationBundle struct { + KeyID string + ValidPrincipals []string + PublicKey ssh.PublicKey + CertificateType uint32 + TTL time.Duration + Signer ssh.Signer + Role *sshRole + CriticalOptions map[string]string + Extensions map[string]string +} + +func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, publicKey ssh.PublicKey) (*logical.Response, error) { + // Note that these various functions always return "user errors" so we pass + // them as 4xx values + keyID, err := b.calculateKeyID(data, req, role, publicKey) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + certificateType, err := b.calculateCertificateType(data, role) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + var parsedPrincipals []string + if certificateType == ssh.HostCert { + parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, validateValidPrincipalForHosts(role)) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + } else { + parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, role.DefaultUser, role.AllowedUsers, strutil.StrListContains) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + } + + ttl, err := b.calculateTTL(data, role) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + criticalOptions, err := b.calculateCriticalOptions(data, role) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + extensions, err := b.calculateExtensions(data, req, role) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + privateKeyEntry, err := caKey(ctx, req.Storage, caPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to read CA private key: %w", err) + } + if privateKeyEntry == nil || privateKeyEntry.Key == "" { + return nil, errors.New("failed to read CA private key") + } + + signer, err := ssh.ParsePrivateKey([]byte(privateKeyEntry.Key)) + if err != nil { + return nil, fmt.Errorf("failed to parse stored CA private key: %w", err) + } + + cBundle := creationBundle{ + KeyID: keyID, + PublicKey: publicKey, + Signer: signer, + ValidPrincipals: parsedPrincipals, + TTL: ttl, + CertificateType: certificateType, + Role: role, + CriticalOptions: criticalOptions, + Extensions: extensions, + } + + certificate, err := cBundle.sign() + if err != nil { + return nil, err + } + + signedSSHCertificate := ssh.MarshalAuthorizedKey(certificate) + if len(signedSSHCertificate) == 0 { + return nil, errors.New("error marshaling signed certificate") + } + + response := &logical.Response{ + Data: map[string]interface{}{ + "serial_number": strconv.FormatUint(certificate.Serial, 16), + "signed_key": string(signedSSHCertificate), + }, + } + + return response, nil +} + +func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, validatePrincipal func([]string, string) bool) ([]string, error) { + validPrincipals := "" + validPrincipalsRaw, ok := data.GetOk("valid_principals") + if ok { + validPrincipals = validPrincipalsRaw.(string) + } else { + validPrincipals = defaultPrincipal + } + + parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false) + // Build list of allowed Principals from template and static principalsAllowedByRole + var allowedPrincipals []string + for _, principal := range strutil.RemoveDuplicates(strutil.ParseStringSlice(principalsAllowedByRole, ","), false) { + if role.AllowedUsersTemplate { + // Look for templating markers {{ .* }} + matched, _ := regexp.MatchString(`{{.+?}}`, principal) + if matched { + if req.EntityID != "" { + // Retrieve principal based on template + entityID from request. + templatePrincipal, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System()) + if err == nil { + // Template returned a principal + allowedPrincipals = append(allowedPrincipals, templatePrincipal) + } else { + return nil, fmt.Errorf("template '%s' could not be rendered -> %s", principal, err) + } + } + } else { + // Static principal or err template + allowedPrincipals = append(allowedPrincipals, principal) + } + } else { + // Static principal + allowedPrincipals = append(allowedPrincipals, principal) + } + } + + switch { + case len(parsedPrincipals) == 0: + // There is nothing to process + return nil, nil + case len(allowedPrincipals) == 0: + // User has requested principals to be set, but role is not configured + // with any principals + return nil, fmt.Errorf("role is not configured to allow any principals") + default: + // Role was explicitly configured to allow any principal. + if principalsAllowedByRole == "*" { + return parsedPrincipals, nil + } + + for _, principal := range parsedPrincipals { + if !validatePrincipal(strutil.RemoveDuplicates(allowedPrincipals, false), principal) { + return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal) + } + } + return parsedPrincipals, nil + } +} + +func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool { + return func(allowedPrincipals []string, validPrincipal string) bool { + for _, allowedPrincipal := range allowedPrincipals { + if allowedPrincipal == validPrincipal && role.AllowBareDomains { + return true + } + if role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+allowedPrincipal) { + return true + } + } + + return false + } +} + +func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) { + requestedCertificateType := data.Get("cert_type").(string) + + var certificateType uint32 + switch requestedCertificateType { + case "user": + if !role.AllowUserCertificates { + return 0, errors.New("cert_type 'user' is not allowed by role") + } + certificateType = ssh.UserCert + case "host": + if !role.AllowHostCertificates { + return 0, errors.New("cert_type 'host' is not allowed by role") + } + certificateType = ssh.HostCert + default: + return 0, errors.New("cert_type must be either 'user' or 'host'") + } + + return certificateType, nil +} + +func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) { + reqID := data.Get("key_id").(string) + + if reqID != "" { + if !role.AllowUserKeyIDs { + return "", fmt.Errorf("setting key_id is not allowed by role") + } + return reqID, nil + } + + keyIDFormat := "vault-{{token_display_name}}-{{public_key_hash}}" + if req.DisplayName == "" { + keyIDFormat = "vault-{{public_key_hash}}" + } + + if role.KeyIDFormat != "" { + keyIDFormat = role.KeyIDFormat + } + + keyID := substQuery(keyIDFormat, map[string]string{ + "token_display_name": req.DisplayName, + "role_name": data.Get("role").(string), + "public_key_hash": fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())), + }) + + return keyID, nil +} + +func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) { + unparsedCriticalOptions := data.Get("critical_options").(map[string]interface{}) + if len(unparsedCriticalOptions) == 0 { + return role.DefaultCriticalOptions, nil + } + + criticalOptions := convertMapToStringValue(unparsedCriticalOptions) + + if role.AllowedCriticalOptions != "" { + notAllowedOptions := []string{} + allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",") + + for option := range criticalOptions { + if !strutil.StrListContains(allowedCriticalOptions, option) { + notAllowedOptions = append(notAllowedOptions, option) + } + } + + if len(notAllowedOptions) != 0 { + return nil, fmt.Errorf("critical options not on allowed list: %v", notAllowedOptions) + } + } + + return criticalOptions, nil +} + +func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, error) { + unparsedExtensions := data.Get("extensions").(map[string]interface{}) + extensions := make(map[string]string) + + if len(unparsedExtensions) > 0 { + extensions := convertMapToStringValue(unparsedExtensions) + if role.AllowedExtensions == "*" { + // Allowed extensions was configured to allow all + return extensions, nil + } + + notAllowed := []string{} + allowedExtensions := strings.Split(role.AllowedExtensions, ",") + for extensionKey := range extensions { + if !strutil.StrListContains(allowedExtensions, extensionKey) { + notAllowed = append(notAllowed, extensionKey) + } + } + + if len(notAllowed) != 0 { + return nil, fmt.Errorf("extensions %v are not on allowed list", notAllowed) + } + return extensions, nil + } + + if role.DefaultExtensionsTemplate { + for extensionKey, extensionValue := range role.DefaultExtensions { + // Look for templating markers {{ .* }} + matched, _ := regexp.MatchString(`^{{.+?}}$`, extensionValue) + if matched { + if req.EntityID != "" { + // Retrieve extension value based on template + entityID from request. + templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System()) + if err == nil { + // Template returned an extension value that we can use + extensions[extensionKey] = templateExtensionValue + } else { + return nil, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err) + } + } + } else { + // Static extension value or err template + extensions[extensionKey] = extensionValue + } + } + } else { + extensions = role.DefaultExtensions + } + + return extensions, nil +} + +func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) { + var ttl, maxTTL time.Duration + var err error + + ttlRaw, specifiedTTL := data.GetOk("ttl") + if specifiedTTL { + ttl = time.Duration(ttlRaw.(int)) * time.Second + } else { + ttl, err = parseutil.ParseDurationSecond(role.TTL) + if err != nil { + return 0, err + } + } + if ttl == 0 { + ttl = b.System().DefaultLeaseTTL() + } + + maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL) + if err != nil { + return 0, err + } + if maxTTL == 0 { + maxTTL = b.System().MaxLeaseTTL() + } + + if ttl > maxTTL { + // Don't error if they were using system defaults, only error if + // they specifically chose a bad TTL + if !specifiedTTL { + ttl = maxTTL + } else { + return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second) + } + } + + return ttl, nil +} + +func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error { + if len(role.AllowedUserKeyTypesLengths) != 0 { + var keyType string + var keyBits int + + switch k := publickey.(type) { + case ssh.CryptoPublicKey: + ff := k.CryptoPublicKey() + switch k := ff.(type) { + case *rsa.PublicKey: + keyType = "rsa" + keyBits = k.N.BitLen() + case *dsa.PublicKey: + keyType = "dsa" + keyBits = k.Parameters.P.BitLen() + case *ecdsa.PublicKey: + keyType = "ecdsa" + keyBits = k.Curve.Params().BitSize + case ed25519.PublicKey: + keyType = "ed25519" + default: + return fmt.Errorf("public key type of %s is not allowed", keyType) + } + default: + return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", k) + } + + keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits) + + var present bool + var pass bool + for _, kstr := range keyTypeToMapKey[keyType] { + allowed_values, ok := role.AllowedUserKeyTypesLengths[kstr] + if !ok { + continue + } + + present = true + + for _, value := range allowed_values { + if keyType == "rsa" || keyType == "dsa" { + // Regardless of map naming, we always need to validate the + // bit length of RSA and DSA keys. Use the keyType flag to + if keyBits == value { + pass = true + } + } else if kstr == "ec" || kstr == "ecdsa" { + // If the map string is "ecdsa", we have to validate the keyBits + // are a match for an allowed value, meaning that our curve + // is allowed. This isn't necessary when a named curve (e.g. + // ssh.KeyAlgoECDSA256) is allowed (and hence kstr is that), + // because keyBits is already specified in the kstr. Thus, + // we have conditioned around kstr and not keyType (like with + // rsa or dsa). + if keyBits == value { + pass = true + } + } else { + // We get here in two cases: we have a algo-named EC key + // matching a format specifier in the key map (e.g., a P-256 + // key with a KeyAlgoECDSA256 entry in the map) or we have a + // ed25519 key (which is always allowed). + pass = true + } + } + } + + if !present { + return fmt.Errorf("key of type %s is not allowed", keyType) + } + + if !pass { + return fmt.Errorf("key is of an invalid size: %v", keyBits) + } + } + return nil +} + +func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) { + defer func() { + if r := recover(); r != nil { + errMsg, ok := r.(string) + if ok { + retCert = nil + retErr = errors.New(errMsg) + } + } + }() + + serialNumber, err := certutil.GenerateSerialNumber() + if err != nil { + return nil, err + } + + now := time.Now() + + sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner) + if !ok { + return nil, fmt.Errorf("failed to generate signed SSH key: signer is not an AlgorithmSigner") + } + + // prepare certificate for signing + nonce := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce") + } + certificate := &ssh.Certificate{ + Serial: serialNumber.Uint64(), + Key: b.PublicKey, + KeyId: b.KeyID, + ValidPrincipals: b.ValidPrincipals, + ValidAfter: uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()), + ValidBefore: uint64(now.Add(b.TTL).In(time.UTC).Unix()), + CertType: b.CertificateType, + Permissions: ssh.Permissions{ + CriticalOptions: b.CriticalOptions, + Extensions: b.Extensions, + }, + Nonce: nonce, + SignatureKey: sshAlgorithmSigner.PublicKey(), + } + + // get bytes to sign; this is based on Certificate.bytesForSigning() from the go ssh lib + out := certificate.Marshal() + // Drop trailing signature length. + certificateBytes := out[:len(out)-4] + + algo := b.Role.AlgorithmSigner + + // Handle the new default algorithm selection process correctly. + if algo == DefaultAlgorithmSigner && sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA { + algo = ssh.SigAlgoRSASHA2256 + } else if algo == DefaultAlgorithmSigner { + algo = "" + } + + sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo) + if err != nil { + return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err) + } + + certificate.Signature = sig + + return certificate, nil +} + +func createKeyTypeToMapKey(keyType string, keyBits int) map[string][]string { + keyTypeToMapKey := map[string][]string{ + "rsa": {"rsa", ssh.KeyAlgoRSA}, + "dsa": {"dsa", ssh.KeyAlgoDSA}, + "ecdsa": {"ecdsa", "ec"}, + "ed25519": {"ed25519", ssh.KeyAlgoED25519}, + } + + if keyType == "ecdsa" { + if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok { + keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo) + } + } + + return keyTypeToMapKey +} diff --git a/builtin/logical/ssh/path_sign.go b/builtin/logical/ssh/path_sign.go index 39d384055..afef41280 100644 --- a/builtin/logical/ssh/path_sign.go +++ b/builtin/logical/ssh/path_sign.go @@ -2,40 +2,12 @@ package ssh import ( "context" - "crypto/dsa" - "crypto/ecdsa" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "errors" "fmt" - "io" - "regexp" - "strconv" - "strings" - "time" - "github.com/hashicorp/go-secure-stdlib/parseutil" - "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/sdk/framework" - "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/logical" - "golang.org/x/crypto/ed25519" - "golang.org/x/crypto/ssh" ) -type creationBundle struct { - KeyID string - ValidPrincipals []string - PublicKey ssh.PublicKey - CertificateType uint32 - TTL time.Duration - Signer ssh.Signer - Role *sshRole - CriticalOptions map[string]string - Extensions map[string]string -} - func pathSign(b *backend) *framework.Path { return &framework.Path{ Pattern: "sign/" + framework.GenericNameWithAtRegex("role"), @@ -120,497 +92,10 @@ func (b *backend) pathSignCertificate(ctx context.Context, req *logical.Request, return logical.ErrorResponse(fmt.Sprintf("public_key failed to meet the key requirements: %s", err)), nil } - // Note that these various functions always return "user errors" so we pass - // them as 4xx values - keyID, err := b.calculateKeyID(data, req, role, userPublicKey) + response, err := b.pathSignIssueCertificateHelper(ctx, req, data, role, userPublicKey) if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - - certificateType, err := b.calculateCertificateType(data, role) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - - var parsedPrincipals []string - if certificateType == ssh.HostCert { - parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, validateValidPrincipalForHosts(role)) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - } else { - parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, role.DefaultUser, role.AllowedUsers, strutil.StrListContains) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - } - - ttl, err := b.calculateTTL(data, role) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - - criticalOptions, err := b.calculateCriticalOptions(data, role) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - - extensions, err := b.calculateExtensions(data, req, role) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - - privateKeyEntry, err := caKey(ctx, req.Storage, caPrivateKey) - if err != nil { - return nil, fmt.Errorf("failed to read CA private key: %w", err) - } - if privateKeyEntry == nil || privateKeyEntry.Key == "" { - return nil, fmt.Errorf("failed to read CA private key") - } - - signer, err := ssh.ParsePrivateKey([]byte(privateKeyEntry.Key)) - if err != nil { - return nil, fmt.Errorf("failed to parse stored CA private key: %w", err) - } - - cBundle := creationBundle{ - KeyID: keyID, - PublicKey: userPublicKey, - Signer: signer, - ValidPrincipals: parsedPrincipals, - TTL: ttl, - CertificateType: certificateType, - Role: role, - CriticalOptions: criticalOptions, - Extensions: extensions, - } - - certificate, err := cBundle.sign() - if err != nil { - return nil, err - } - - signedSSHCertificate := ssh.MarshalAuthorizedKey(certificate) - if len(signedSSHCertificate) == 0 { - return nil, fmt.Errorf("error marshaling signed certificate") - } - - response := &logical.Response{ - Data: map[string]interface{}{ - "serial_number": strconv.FormatUint(certificate.Serial, 16), - "signed_key": string(signedSSHCertificate), - }, + return response, err } return response, nil } - -func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, validatePrincipal func([]string, string) bool) ([]string, error) { - validPrincipals := "" - validPrincipalsRaw, ok := data.GetOk("valid_principals") - if ok { - validPrincipals = validPrincipalsRaw.(string) - } else { - validPrincipals = defaultPrincipal - } - - parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false) - // Build list of allowed Principals from template and static principalsAllowedByRole - var allowedPrincipals []string - for _, principal := range strutil.RemoveDuplicates(strutil.ParseStringSlice(principalsAllowedByRole, ","), false) { - if role.AllowedUsersTemplate { - // Look for templating markers {{ .* }} - matched, _ := regexp.MatchString(`{{.+?}}`, principal) - if matched { - if req.EntityID != "" { - // Retrieve principal based on template + entityID from request. - templatePrincipal, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System()) - if err == nil { - // Template returned a principal - allowedPrincipals = append(allowedPrincipals, templatePrincipal) - } else { - return nil, fmt.Errorf("template '%s' could not be rendered -> %s", principal, err) - } - } - } else { - // Static principal or err template - allowedPrincipals = append(allowedPrincipals, principal) - } - } else { - // Static principal - allowedPrincipals = append(allowedPrincipals, principal) - } - } - - switch { - case len(parsedPrincipals) == 0: - // There is nothing to process - return nil, nil - case len(allowedPrincipals) == 0: - // User has requested principals to be set, but role is not configured - // with any principals - return nil, fmt.Errorf("role is not configured to allow any principals") - default: - // Role was explicitly configured to allow any principal. - if principalsAllowedByRole == "*" { - return parsedPrincipals, nil - } - - for _, principal := range parsedPrincipals { - if !validatePrincipal(strutil.RemoveDuplicates(allowedPrincipals, false), principal) { - return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal) - } - } - return parsedPrincipals, nil - } -} - -func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool { - return func(allowedPrincipals []string, validPrincipal string) bool { - for _, allowedPrincipal := range allowedPrincipals { - if allowedPrincipal == validPrincipal && role.AllowBareDomains { - return true - } - if role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+allowedPrincipal) { - return true - } - } - - return false - } -} - -func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) { - requestedCertificateType := data.Get("cert_type").(string) - - var certificateType uint32 - switch requestedCertificateType { - case "user": - if !role.AllowUserCertificates { - return 0, errors.New("cert_type 'user' is not allowed by role") - } - certificateType = ssh.UserCert - case "host": - if !role.AllowHostCertificates { - return 0, errors.New("cert_type 'host' is not allowed by role") - } - certificateType = ssh.HostCert - default: - return 0, errors.New("cert_type must be either 'user' or 'host'") - } - - return certificateType, nil -} - -func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) { - reqID := data.Get("key_id").(string) - - if reqID != "" { - if !role.AllowUserKeyIDs { - return "", fmt.Errorf("setting key_id is not allowed by role") - } - return reqID, nil - } - - keyIDFormat := "vault-{{token_display_name}}-{{public_key_hash}}" - if req.DisplayName == "" { - keyIDFormat = "vault-{{public_key_hash}}" - } - - if role.KeyIDFormat != "" { - keyIDFormat = role.KeyIDFormat - } - - keyID := substQuery(keyIDFormat, map[string]string{ - "token_display_name": req.DisplayName, - "role_name": data.Get("role").(string), - "public_key_hash": fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())), - }) - - return keyID, nil -} - -func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) { - unparsedCriticalOptions := data.Get("critical_options").(map[string]interface{}) - if len(unparsedCriticalOptions) == 0 { - return role.DefaultCriticalOptions, nil - } - - criticalOptions := convertMapToStringValue(unparsedCriticalOptions) - - if role.AllowedCriticalOptions != "" { - notAllowedOptions := []string{} - allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",") - - for option := range criticalOptions { - if !strutil.StrListContains(allowedCriticalOptions, option) { - notAllowedOptions = append(notAllowedOptions, option) - } - } - - if len(notAllowedOptions) != 0 { - return nil, fmt.Errorf("critical options not on allowed list: %v", notAllowedOptions) - } - } - - return criticalOptions, nil -} - -func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, error) { - unparsedExtensions := data.Get("extensions").(map[string]interface{}) - extensions := make(map[string]string) - - if len(unparsedExtensions) > 0 { - extensions := convertMapToStringValue(unparsedExtensions) - if role.AllowedExtensions == "*" { - // Allowed extensions was configured to allow all - return extensions, nil - } - - notAllowed := []string{} - allowedExtensions := strings.Split(role.AllowedExtensions, ",") - for extensionKey := range extensions { - if !strutil.StrListContains(allowedExtensions, extensionKey) { - notAllowed = append(notAllowed, extensionKey) - } - } - - if len(notAllowed) != 0 { - return nil, fmt.Errorf("extensions %v are not on allowed list", notAllowed) - } - return extensions, nil - } - - if role.DefaultExtensionsTemplate { - for extensionKey, extensionValue := range role.DefaultExtensions { - // Look for templating markers {{ .* }} - matched, _ := regexp.MatchString(`^{{.+?}}$`, extensionValue) - if matched { - if req.EntityID != "" { - // Retrieve extension value based on template + entityID from request. - templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System()) - if err == nil { - // Template returned an extension value that we can use - extensions[extensionKey] = templateExtensionValue - } else { - return nil, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err) - } - } - } else { - // Static extension value or err template - extensions[extensionKey] = extensionValue - } - } - } else { - extensions = role.DefaultExtensions - } - - return extensions, nil -} - -func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) { - var ttl, maxTTL time.Duration - var err error - - ttlRaw, specifiedTTL := data.GetOk("ttl") - if specifiedTTL { - ttl = time.Duration(ttlRaw.(int)) * time.Second - } else { - ttl, err = parseutil.ParseDurationSecond(role.TTL) - if err != nil { - return 0, err - } - } - if ttl == 0 { - ttl = b.System().DefaultLeaseTTL() - } - - maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL) - if err != nil { - return 0, err - } - if maxTTL == 0 { - maxTTL = b.System().MaxLeaseTTL() - } - - if ttl > maxTTL { - // Don't error if they were using system defaults, only error if - // they specifically chose a bad TTL - if !specifiedTTL { - ttl = maxTTL - } else { - return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second) - } - } - - return ttl, nil -} - -func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error { - if len(role.AllowedUserKeyTypesLengths) != 0 { - var keyType string - var keyBits int - - switch k := publickey.(type) { - case ssh.CryptoPublicKey: - ff := k.CryptoPublicKey() - switch k := ff.(type) { - case *rsa.PublicKey: - keyType = "rsa" - keyBits = k.N.BitLen() - case *dsa.PublicKey: - keyType = "dsa" - keyBits = k.Parameters.P.BitLen() - case *ecdsa.PublicKey: - keyType = "ecdsa" - keyBits = k.Curve.Params().BitSize - case ed25519.PublicKey: - keyType = "ed25519" - default: - return fmt.Errorf("public key type of %s is not allowed", keyType) - } - default: - return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", k) - } - - keyTypeToMapKey := map[string][]string{ - "rsa": {"rsa", ssh.KeyAlgoRSA}, - "dsa": {"dsa", ssh.KeyAlgoDSA}, - "ecdsa": {"ecdsa", "ec"}, - "ed25519": {"ed25519", ssh.KeyAlgoED25519}, - } - - if keyType == "ecdsa" { - ecCurveBitsToAlgoName := map[int]string{ - 256: ssh.KeyAlgoECDSA256, - 384: ssh.KeyAlgoECDSA384, - 521: ssh.KeyAlgoECDSA521, - } - - if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok { - keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo) - } - - // If the algorithm is not found, it could be that we have a curve - // that we haven't added a constant for yet. But they could allow it - // (assuming x/crypto/ssh can parse it) via setting a ec: - // mapping rather than using a named SSH key type, so erring out here - // isn't advisable. - } - - var present bool - var pass bool - for _, kstr := range keyTypeToMapKey[keyType] { - allowed_values, ok := role.AllowedUserKeyTypesLengths[kstr] - if !ok { - continue - } - - present = true - - for _, value := range allowed_values { - if keyType == "rsa" || keyType == "dsa" { - // Regardless of map naming, we always need to validate the - // bit length of RSA and DSA keys. Use the keyType flag to - if keyBits == value { - pass = true - } - } else if kstr == "ec" || kstr == "ecdsa" { - // If the map string is "ecdsa", we have to validate the keyBits - // are a match for an allowed value, meaning that our curve - // is allowed. This isn't necessary when a named curve (e.g. - // ssh.KeyAlgoECDSA256) is allowed (and hence kstr is that), - // because keyBits is already specified in the kstr. Thus, - // we have conditioned around kstr and not keyType (like with - // rsa or dsa). - if keyBits == value { - pass = true - } - } else { - // We get here in two cases: we have a algo-named EC key - // matching a format specifier in the key map (e.g., a P-256 - // key with a KeyAlgoECDSA256 entry in the map) or we have a - // ed25519 key (which is always allowed). - pass = true - } - } - } - - if !present { - return fmt.Errorf("key of type %s is not allowed", keyType) - } - - if !pass { - return fmt.Errorf("key is of an invalid size: %v", keyBits) - } - } - return nil -} - -func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) { - defer func() { - if r := recover(); r != nil { - errMsg, ok := r.(string) - if ok { - retCert = nil - retErr = errors.New(errMsg) - } - } - }() - - serialNumber, err := certutil.GenerateSerialNumber() - if err != nil { - return nil, err - } - - now := time.Now() - - sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner) - if !ok { - return nil, fmt.Errorf("failed to generate signed SSH key: signer is not an AlgorithmSigner") - } - - // prepare certificate for signing - nonce := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce") - } - certificate := &ssh.Certificate{ - Serial: serialNumber.Uint64(), - Key: b.PublicKey, - KeyId: b.KeyID, - ValidPrincipals: b.ValidPrincipals, - ValidAfter: uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()), - ValidBefore: uint64(now.Add(b.TTL).In(time.UTC).Unix()), - CertType: b.CertificateType, - Permissions: ssh.Permissions{ - CriticalOptions: b.CriticalOptions, - Extensions: b.Extensions, - }, - Nonce: nonce, - SignatureKey: sshAlgorithmSigner.PublicKey(), - } - - // get bytes to sign; this is based on Certificate.bytesForSigning() from the go ssh lib - out := certificate.Marshal() - // Drop trailing signature length. - certificateBytes := out[:len(out)-4] - - algo := b.Role.AlgorithmSigner - - // Handle the new default algorithm selection process correctly. - if algo == DefaultAlgorithmSigner && sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA { - algo = ssh.SigAlgoRSASHA2256 - } else if algo == DefaultAlgorithmSigner { - algo = "" - } - - sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo) - if err != nil { - return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err) - } - - certificate.Signature = sig - - return certificate, nil -} diff --git a/changelog/15561.txt b/changelog/15561.txt new file mode 100644 index 000000000..95787b654 --- /dev/null +++ b/changelog/15561.txt @@ -0,0 +1,3 @@ +```release-note:improvement +ssh: Addition of an endpoint `ssh/issue/:role` to allow the creation of signed key pairs +``` diff --git a/website/content/api-docs/secret/ssh.mdx b/website/content/api-docs/secret/ssh.mdx index ae6762a7a..3ca424563 100644 --- a/website/content/api-docs/secret/ssh.mdx +++ b/website/content/api-docs/secret/ssh.mdx @@ -873,3 +873,88 @@ $ curl \ "auth": null } ``` + +## Generate Certificate and Key + +This endpoint generates a new set of credentials (private key and certificate) +based on the role named in the endpoint. + +~> **Note**: The private key is _not_ stored. If you do not save the private + key from the response, you will need to request a new certificate. + +| Method | Path | +| :----- | :---------------- | +| `POST` | `/ssh/issue/:name` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to create the + certificate against. This is part of the request URL. + +- `key_type` `(string: "rsa")` – Specifies the desired key type; must be `rsa`, `ed25519` + or `ec`. + +- `key_bits` `(int: 0)` – Specifies the number of bits to use for the + generated keys. Allowed values are 0 (universal default); with + `key_type=rsa`, allowed values are: 2048 (default), 3072, or + 4096; with `key_type=ec`, allowed values are: 224, 256 (default), + 384, or 521; ignored with `key_type=ed25519`. + +- `ttl` `(string: "")` – Specifies the Requested Time To Live. Cannot be greater + than the role's `max_ttl` value. If not provided, the role's `ttl` value will + be used. Note that the role values default to system values if not explicitly + set. + +- `valid_principals` `(string: "")` – Specifies valid principals, either + usernames or hostnames, that the certificate should be signed for. + +- `cert_type` `(string: "user")` – Specifies the type of certificate to be + created; either "user" or "host". + +- `key_id` `(string: "")` – Specifies the key id that the created certificate + should have. If not specified, the display name of the token will be used. + +- `critical_options` `(map: "")` – Specifies a map of the + critical options that the certificate should be signed for. Defaults to none. + +- `extensions` `(map: "")` – Specifies a map of the extensions + that the certificate should be signed for. Defaults to none. + +### Sample Payload + +```json +{ + "key_type": "rsa", + "key_bits": 2048 +} +``` + +### Sample Request + +```shell-session +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + http://127.0.0.1:8200/v1/ssh/issue/my-role +``` + +### Sample Response + +```json +{ + "request_id": "94fd1102-08a1-c207-0e3e-657e8f80c09e", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "serial_number": "1e965817eb12a511", + "signed_key": "ssh-rsa-cert-v01@openssh.com AAAAHHN...\n", + "private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIEpQIBAAKCAQEAwer03vkQrPV+wWpbisJJv2CKqHmMz+Ej0ctLbhpOmR2CY9S9\n...\nQN351pgTphi6nlCkGPzkDuwvtxSxiCWXQcaxrHAL7MiJpPzkIBq1\n-----END RSA PRIVATE KEY-----\n", + "private_key_type": "rsa" + }, + "wrap_info": null, + "warnings": null, + "auth": null +} +```