SSH secrets engine - Enabled creation of key pairs (CA Mode) (#15561)
* Handle func * Update - check if key_type and key_bits are allowed * Update - fields * Generating keys based on provided key_type and key_bits * Returning signed key * Refactor * Refactor update to common logic function * Descriptions * Tests added * Suggested changes and tests added and refactored * Suggested changes and fmt run * File refactoring * Changelog file * Update changelog/15561.txt Co-authored-by: Alexander Scheel <alexander.m.scheel@gmail.com> * Suggested changes - consistent returns and additional info to test messages * ssh issue key pair documentation Co-authored-by: Alexander Scheel <alexander.m.scheel@gmail.com>
This commit is contained in:
parent
17eed2a814
commit
57eeb33faa
|
@ -61,6 +61,7 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||
pathVerify(&b),
|
||||
pathConfigCA(&b),
|
||||
pathSign(&b),
|
||||
pathIssue(&b),
|
||||
pathFetchPublicKey(&b),
|
||||
},
|
||||
|
||||
|
|
|
@ -1776,6 +1776,50 @@ func TestSSHBackend_ValidateNotBeforeDuration(t *testing.T) {
|
|||
logicaltest.Test(t, testCase)
|
||||
}
|
||||
|
||||
func TestSSHBackend_IssueSign(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot create backend: %s", err)
|
||||
}
|
||||
|
||||
testCase := logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
configCaStep(testCAPublicKey, testCAPrivateKey),
|
||||
|
||||
createRoleStep("testing", map[string]interface{}{
|
||||
"key_type": "otp",
|
||||
"default_user": "user",
|
||||
}),
|
||||
// Key pair not issued with invalid role key type
|
||||
issueSSHKeyPairStep("testing", "rsa", 0, true, "role key type 'otp' not allowed to issue key pairs"),
|
||||
|
||||
createRoleStep("testing", map[string]interface{}{
|
||||
"key_type": "ca",
|
||||
"allow_user_key_ids": false,
|
||||
"allow_user_certificates": true,
|
||||
"allowed_user_key_lengths": map[string]interface{}{
|
||||
"ssh-rsa": []int{2048, 3072, 4096},
|
||||
"ecdsa-sha2-nistp521": 0,
|
||||
"ed25519": 0,
|
||||
},
|
||||
}),
|
||||
// Key_type not in allowed_user_key_types_lengths
|
||||
issueSSHKeyPairStep("testing", "ec", 256, true, "provided key_type value not in allowed_user_key_types"),
|
||||
// Key_bits not in allowed_user_key_types_lengths for provided key_type
|
||||
issueSSHKeyPairStep("testing", "rsa", 2560, true, "provided key_bits value not in list of role's allowed_user_key_types"),
|
||||
// key_type `rsa` and key_bits `2048` successfully created
|
||||
issueSSHKeyPairStep("testing", "rsa", 2048, false, ""),
|
||||
// key_type `ed22519` and key_bits `0` successfully created
|
||||
issueSSHKeyPairStep("testing", "ed25519", 0, false, ""),
|
||||
},
|
||||
}
|
||||
|
||||
logicaltest.Test(t, testCase)
|
||||
}
|
||||
|
||||
func getSshCaTestCluster(t *testing.T, userIdentity string) (*vault.TestCluster, string) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
|
@ -1847,7 +1891,8 @@ func getSshCaTestCluster(t *testing.T, userIdentity string) (*vault.TestCluster,
|
|||
}
|
||||
|
||||
func testAllowedUsersTemplate(t *testing.T, testAllowedUsersTemplate string,
|
||||
expectedValidPrincipal string, testEntityMetadata map[string]string) {
|
||||
expectedValidPrincipal string, testEntityMetadata map[string]string,
|
||||
) {
|
||||
cluster, userpassToken := getSshCaTestCluster(t, testUserName)
|
||||
defer cluster.Cleanup()
|
||||
client := cluster.Cores[0].Client
|
||||
|
@ -1926,7 +1971,8 @@ func signCertificateStep(
|
|||
role, keyID string, certType int, validPrincipals []string,
|
||||
criticalOptionPermissions, extensionPermissions map[string]string,
|
||||
ttl time.Duration,
|
||||
requestParameters map[string]interface{}) logicaltest.TestStep {
|
||||
requestParameters map[string]interface{},
|
||||
) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sign/" + role,
|
||||
|
@ -1955,6 +2001,42 @@ func signCertificateStep(
|
|||
}
|
||||
}
|
||||
|
||||
func issueSSHKeyPairStep(role, keyType string, keyBits int, expectError bool, errorMsg string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "issue/" + role,
|
||||
Data: map[string]interface{}{
|
||||
"key_type": keyType,
|
||||
"key_bits": keyBits,
|
||||
},
|
||||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if expectError {
|
||||
var err error
|
||||
if resp.Data["error"] != errorMsg {
|
||||
err = fmt.Errorf("actual error message \"%s\" different from expected error message \"%s\"", resp.Data["error"], errorMsg)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("unexpected error response returned: %v", resp.Error())
|
||||
}
|
||||
|
||||
if resp.Data["private_key_type"] != keyType {
|
||||
return fmt.Errorf("response private_key_type (%s) does not match the provided key_type (%s)", resp.Data["private_key_type"], keyType)
|
||||
}
|
||||
|
||||
if resp.Data["signed_key"] == "" {
|
||||
return errors.New("certificate/signed_key should not be empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validateSSHCertificate(cert *ssh.Certificate, keyID string, certType int, validPrincipals []string, criticalOptionPermissions, extensionPermissions map[string]string,
|
||||
ttl time.Duration,
|
||||
) error {
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
type keySpecs struct {
|
||||
Type string
|
||||
Bits int
|
||||
}
|
||||
|
||||
func pathIssue(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "issue/" + framework.GenericNameWithAtRegex("role"),
|
||||
|
||||
Operations: map[logical.Operation]framework.OperationHandler{
|
||||
logical.UpdateOperation: &framework.PathOperation{
|
||||
Callback: b.pathIssue,
|
||||
},
|
||||
},
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"role": {
|
||||
Type: framework.TypeString,
|
||||
Description: `The desired role with configuration for this request.`,
|
||||
},
|
||||
"key_type": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Specifies the desired key type; must be `rsa`, `ed25519` or `ec`",
|
||||
Default: "rsa",
|
||||
},
|
||||
"key_bits": {
|
||||
Type: framework.TypeInt,
|
||||
Description: "Specifies the number of bits to use for the generated keys.",
|
||||
Default: 0,
|
||||
},
|
||||
"ttl": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `The requested Time To Live for the SSH certificate;
|
||||
sets the expiration date. If not specified
|
||||
the role default, backend default, or system
|
||||
default TTL is used, in that order. Cannot
|
||||
be later than the role max TTL.`,
|
||||
},
|
||||
"valid_principals": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Valid principals, either usernames or hostnames, that the certificate should be signed for.`,
|
||||
},
|
||||
"cert_type": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Type of certificate to be created; either "user" or "host".`,
|
||||
Default: "user",
|
||||
},
|
||||
"key_id": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Key id that the created certificate should have. If not specified, the display name of the token will be used.`,
|
||||
},
|
||||
"critical_options": {
|
||||
Type: framework.TypeMap,
|
||||
Description: `Critical options that the certificate should be signed for.`,
|
||||
},
|
||||
"extensions": {
|
||||
Type: framework.TypeMap,
|
||||
Description: `Extensions that the certificate should be signed for.`,
|
||||
},
|
||||
},
|
||||
HelpSynopsis: pathIssueHelpSyn,
|
||||
HelpDescription: pathIssueHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the role
|
||||
roleName := data.Get("role").(string)
|
||||
role, err := b.getRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", roleName)), nil
|
||||
}
|
||||
|
||||
if role.KeyType != "ca" {
|
||||
return logical.ErrorResponse("role key type '%s' not allowed to issue key pairs", role.KeyType), nil
|
||||
}
|
||||
|
||||
// Validate and extract key specifications
|
||||
keySpecs, err := extractKeySpecs(role, data)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Issue certificate
|
||||
return b.pathIssueCertificate(ctx, req, data, role, keySpecs)
|
||||
}
|
||||
|
||||
func (b *backend) pathIssueCertificate(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, keySpecs *keySpecs) (*logical.Response, error) {
|
||||
publicKey, privateKey, err := generateSSHKeyPair(rand.Reader, keySpecs.Type, keySpecs.Bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sign key
|
||||
userPublicKey, err := parsePublicSSHKey(publicKey)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to parse public_key as SSH key: %s", err)), nil
|
||||
}
|
||||
|
||||
response, err := b.pathSignIssueCertificateHelper(ctx, req, data, role, userPublicKey)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Additional to sign response
|
||||
response.Data["private_key"] = privateKey
|
||||
response.Data["private_key_type"] = keySpecs.Type
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func extractKeySpecs(role *sshRole, data *framework.FieldData) (*keySpecs, error) {
|
||||
keyType := data.Get("key_type").(string)
|
||||
keyBits := data.Get("key_bits").(int)
|
||||
keySpecs := keySpecs{
|
||||
Type: keyType,
|
||||
Bits: keyBits,
|
||||
}
|
||||
|
||||
keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits)
|
||||
|
||||
if len(role.AllowedUserKeyTypesLengths) != 0 {
|
||||
var keyAllowed bool
|
||||
var bitsAllowed bool
|
||||
|
||||
keyTypeAliasesLoop:
|
||||
for _, keyTypeAlias := range keyTypeToMapKey[keyType] {
|
||||
allowedValues, allowed := role.AllowedUserKeyTypesLengths[keyTypeAlias]
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
keyAllowed = true
|
||||
|
||||
for _, value := range allowedValues {
|
||||
if value == keyBits {
|
||||
bitsAllowed = true
|
||||
break keyTypeAliasesLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !keyAllowed {
|
||||
return nil, errors.New("provided key_type value not in allowed_user_key_types")
|
||||
}
|
||||
|
||||
if !bitsAllowed {
|
||||
return nil, errors.New("provided key_bits value not in list of role's allowed_user_key_types")
|
||||
}
|
||||
}
|
||||
|
||||
return &keySpecs, nil
|
||||
}
|
||||
|
||||
const pathIssueHelpSyn = `
|
||||
Request a certificate using a certain role with the provided details.
|
||||
`
|
||||
|
||||
const pathIssueHelpDesc = `
|
||||
This path allows requesting a certificate to be issued according to the
|
||||
policy of the given role. The certificate will only be issued if the
|
||||
requested details are allowed by the role policy.
|
||||
|
||||
This path returns a certificate and a private key. If you want a workflow
|
||||
that does not expose a private key, generate a CSR locally and use the
|
||||
sign path instead.
|
||||
`
|
|
@ -0,0 +1,539 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/dsa"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var ecCurveBitsToAlgoName = map[int]string{
|
||||
256: ssh.KeyAlgoECDSA256,
|
||||
384: ssh.KeyAlgoECDSA384,
|
||||
521: ssh.KeyAlgoECDSA521,
|
||||
}
|
||||
|
||||
// If the algorithm is not found, it could be that we have a curve
|
||||
// that we haven't added a constant for yet. But they could allow it
|
||||
// (assuming x/crypto/ssh can parse it) via setting a ec: <keyBits>
|
||||
// mapping rather than using a named SSH key type, so erring out here
|
||||
// isn't advisable.
|
||||
|
||||
type creationBundle struct {
|
||||
KeyID string
|
||||
ValidPrincipals []string
|
||||
PublicKey ssh.PublicKey
|
||||
CertificateType uint32
|
||||
TTL time.Duration
|
||||
Signer ssh.Signer
|
||||
Role *sshRole
|
||||
CriticalOptions map[string]string
|
||||
Extensions map[string]string
|
||||
}
|
||||
|
||||
func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logical.Request, data *framework.FieldData, role *sshRole, publicKey ssh.PublicKey) (*logical.Response, error) {
|
||||
// Note that these various functions always return "user errors" so we pass
|
||||
// them as 4xx values
|
||||
keyID, err := b.calculateKeyID(data, req, role, publicKey)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
certificateType, err := b.calculateCertificateType(data, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
var parsedPrincipals []string
|
||||
if certificateType == ssh.HostCert {
|
||||
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, validateValidPrincipalForHosts(role))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
} else {
|
||||
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, role.DefaultUser, role.AllowedUsers, strutil.StrListContains)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
}
|
||||
|
||||
ttl, err := b.calculateTTL(data, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
criticalOptions, err := b.calculateCriticalOptions(data, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
extensions, err := b.calculateExtensions(data, req, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
privateKeyEntry, err := caKey(ctx, req.Storage, caPrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA private key: %w", err)
|
||||
}
|
||||
if privateKeyEntry == nil || privateKeyEntry.Key == "" {
|
||||
return nil, errors.New("failed to read CA private key")
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey([]byte(privateKeyEntry.Key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse stored CA private key: %w", err)
|
||||
}
|
||||
|
||||
cBundle := creationBundle{
|
||||
KeyID: keyID,
|
||||
PublicKey: publicKey,
|
||||
Signer: signer,
|
||||
ValidPrincipals: parsedPrincipals,
|
||||
TTL: ttl,
|
||||
CertificateType: certificateType,
|
||||
Role: role,
|
||||
CriticalOptions: criticalOptions,
|
||||
Extensions: extensions,
|
||||
}
|
||||
|
||||
certificate, err := cBundle.sign()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signedSSHCertificate := ssh.MarshalAuthorizedKey(certificate)
|
||||
if len(signedSSHCertificate) == 0 {
|
||||
return nil, errors.New("error marshaling signed certificate")
|
||||
}
|
||||
|
||||
response := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"serial_number": strconv.FormatUint(certificate.Serial, 16),
|
||||
"signed_key": string(signedSSHCertificate),
|
||||
},
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, validatePrincipal func([]string, string) bool) ([]string, error) {
|
||||
validPrincipals := ""
|
||||
validPrincipalsRaw, ok := data.GetOk("valid_principals")
|
||||
if ok {
|
||||
validPrincipals = validPrincipalsRaw.(string)
|
||||
} else {
|
||||
validPrincipals = defaultPrincipal
|
||||
}
|
||||
|
||||
parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false)
|
||||
// Build list of allowed Principals from template and static principalsAllowedByRole
|
||||
var allowedPrincipals []string
|
||||
for _, principal := range strutil.RemoveDuplicates(strutil.ParseStringSlice(principalsAllowedByRole, ","), false) {
|
||||
if role.AllowedUsersTemplate {
|
||||
// Look for templating markers {{ .* }}
|
||||
matched, _ := regexp.MatchString(`{{.+?}}`, principal)
|
||||
if matched {
|
||||
if req.EntityID != "" {
|
||||
// Retrieve principal based on template + entityID from request.
|
||||
templatePrincipal, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System())
|
||||
if err == nil {
|
||||
// Template returned a principal
|
||||
allowedPrincipals = append(allowedPrincipals, templatePrincipal)
|
||||
} else {
|
||||
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", principal, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Static principal or err template
|
||||
allowedPrincipals = append(allowedPrincipals, principal)
|
||||
}
|
||||
} else {
|
||||
// Static principal
|
||||
allowedPrincipals = append(allowedPrincipals, principal)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(parsedPrincipals) == 0:
|
||||
// There is nothing to process
|
||||
return nil, nil
|
||||
case len(allowedPrincipals) == 0:
|
||||
// User has requested principals to be set, but role is not configured
|
||||
// with any principals
|
||||
return nil, fmt.Errorf("role is not configured to allow any principals")
|
||||
default:
|
||||
// Role was explicitly configured to allow any principal.
|
||||
if principalsAllowedByRole == "*" {
|
||||
return parsedPrincipals, nil
|
||||
}
|
||||
|
||||
for _, principal := range parsedPrincipals {
|
||||
if !validatePrincipal(strutil.RemoveDuplicates(allowedPrincipals, false), principal) {
|
||||
return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal)
|
||||
}
|
||||
}
|
||||
return parsedPrincipals, nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
|
||||
return func(allowedPrincipals []string, validPrincipal string) bool {
|
||||
for _, allowedPrincipal := range allowedPrincipals {
|
||||
if allowedPrincipal == validPrincipal && role.AllowBareDomains {
|
||||
return true
|
||||
}
|
||||
if role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+allowedPrincipal) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
|
||||
requestedCertificateType := data.Get("cert_type").(string)
|
||||
|
||||
var certificateType uint32
|
||||
switch requestedCertificateType {
|
||||
case "user":
|
||||
if !role.AllowUserCertificates {
|
||||
return 0, errors.New("cert_type 'user' is not allowed by role")
|
||||
}
|
||||
certificateType = ssh.UserCert
|
||||
case "host":
|
||||
if !role.AllowHostCertificates {
|
||||
return 0, errors.New("cert_type 'host' is not allowed by role")
|
||||
}
|
||||
certificateType = ssh.HostCert
|
||||
default:
|
||||
return 0, errors.New("cert_type must be either 'user' or 'host'")
|
||||
}
|
||||
|
||||
return certificateType, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) {
|
||||
reqID := data.Get("key_id").(string)
|
||||
|
||||
if reqID != "" {
|
||||
if !role.AllowUserKeyIDs {
|
||||
return "", fmt.Errorf("setting key_id is not allowed by role")
|
||||
}
|
||||
return reqID, nil
|
||||
}
|
||||
|
||||
keyIDFormat := "vault-{{token_display_name}}-{{public_key_hash}}"
|
||||
if req.DisplayName == "" {
|
||||
keyIDFormat = "vault-{{public_key_hash}}"
|
||||
}
|
||||
|
||||
if role.KeyIDFormat != "" {
|
||||
keyIDFormat = role.KeyIDFormat
|
||||
}
|
||||
|
||||
keyID := substQuery(keyIDFormat, map[string]string{
|
||||
"token_display_name": req.DisplayName,
|
||||
"role_name": data.Get("role").(string),
|
||||
"public_key_hash": fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())),
|
||||
})
|
||||
|
||||
return keyID, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
|
||||
unparsedCriticalOptions := data.Get("critical_options").(map[string]interface{})
|
||||
if len(unparsedCriticalOptions) == 0 {
|
||||
return role.DefaultCriticalOptions, nil
|
||||
}
|
||||
|
||||
criticalOptions := convertMapToStringValue(unparsedCriticalOptions)
|
||||
|
||||
if role.AllowedCriticalOptions != "" {
|
||||
notAllowedOptions := []string{}
|
||||
allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",")
|
||||
|
||||
for option := range criticalOptions {
|
||||
if !strutil.StrListContains(allowedCriticalOptions, option) {
|
||||
notAllowedOptions = append(notAllowedOptions, option)
|
||||
}
|
||||
}
|
||||
|
||||
if len(notAllowedOptions) != 0 {
|
||||
return nil, fmt.Errorf("critical options not on allowed list: %v", notAllowedOptions)
|
||||
}
|
||||
}
|
||||
|
||||
return criticalOptions, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, error) {
|
||||
unparsedExtensions := data.Get("extensions").(map[string]interface{})
|
||||
extensions := make(map[string]string)
|
||||
|
||||
if len(unparsedExtensions) > 0 {
|
||||
extensions := convertMapToStringValue(unparsedExtensions)
|
||||
if role.AllowedExtensions == "*" {
|
||||
// Allowed extensions was configured to allow all
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
notAllowed := []string{}
|
||||
allowedExtensions := strings.Split(role.AllowedExtensions, ",")
|
||||
for extensionKey := range extensions {
|
||||
if !strutil.StrListContains(allowedExtensions, extensionKey) {
|
||||
notAllowed = append(notAllowed, extensionKey)
|
||||
}
|
||||
}
|
||||
|
||||
if len(notAllowed) != 0 {
|
||||
return nil, fmt.Errorf("extensions %v are not on allowed list", notAllowed)
|
||||
}
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
if role.DefaultExtensionsTemplate {
|
||||
for extensionKey, extensionValue := range role.DefaultExtensions {
|
||||
// Look for templating markers {{ .* }}
|
||||
matched, _ := regexp.MatchString(`^{{.+?}}$`, extensionValue)
|
||||
if matched {
|
||||
if req.EntityID != "" {
|
||||
// Retrieve extension value based on template + entityID from request.
|
||||
templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System())
|
||||
if err == nil {
|
||||
// Template returned an extension value that we can use
|
||||
extensions[extensionKey] = templateExtensionValue
|
||||
} else {
|
||||
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Static extension value or err template
|
||||
extensions[extensionKey] = extensionValue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
extensions = role.DefaultExtensions
|
||||
}
|
||||
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
|
||||
var ttl, maxTTL time.Duration
|
||||
var err error
|
||||
|
||||
ttlRaw, specifiedTTL := data.GetOk("ttl")
|
||||
if specifiedTTL {
|
||||
ttl = time.Duration(ttlRaw.(int)) * time.Second
|
||||
} else {
|
||||
ttl, err = parseutil.ParseDurationSecond(role.TTL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if ttl == 0 {
|
||||
ttl = b.System().DefaultLeaseTTL()
|
||||
}
|
||||
|
||||
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if maxTTL == 0 {
|
||||
maxTTL = b.System().MaxLeaseTTL()
|
||||
}
|
||||
|
||||
if ttl > maxTTL {
|
||||
// Don't error if they were using system defaults, only error if
|
||||
// they specifically chose a bad TTL
|
||||
if !specifiedTTL {
|
||||
ttl = maxTTL
|
||||
} else {
|
||||
return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error {
|
||||
if len(role.AllowedUserKeyTypesLengths) != 0 {
|
||||
var keyType string
|
||||
var keyBits int
|
||||
|
||||
switch k := publickey.(type) {
|
||||
case ssh.CryptoPublicKey:
|
||||
ff := k.CryptoPublicKey()
|
||||
switch k := ff.(type) {
|
||||
case *rsa.PublicKey:
|
||||
keyType = "rsa"
|
||||
keyBits = k.N.BitLen()
|
||||
case *dsa.PublicKey:
|
||||
keyType = "dsa"
|
||||
keyBits = k.Parameters.P.BitLen()
|
||||
case *ecdsa.PublicKey:
|
||||
keyType = "ecdsa"
|
||||
keyBits = k.Curve.Params().BitSize
|
||||
case ed25519.PublicKey:
|
||||
keyType = "ed25519"
|
||||
default:
|
||||
return fmt.Errorf("public key type of %s is not allowed", keyType)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", k)
|
||||
}
|
||||
|
||||
keyTypeToMapKey := createKeyTypeToMapKey(keyType, keyBits)
|
||||
|
||||
var present bool
|
||||
var pass bool
|
||||
for _, kstr := range keyTypeToMapKey[keyType] {
|
||||
allowed_values, ok := role.AllowedUserKeyTypesLengths[kstr]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
present = true
|
||||
|
||||
for _, value := range allowed_values {
|
||||
if keyType == "rsa" || keyType == "dsa" {
|
||||
// Regardless of map naming, we always need to validate the
|
||||
// bit length of RSA and DSA keys. Use the keyType flag to
|
||||
if keyBits == value {
|
||||
pass = true
|
||||
}
|
||||
} else if kstr == "ec" || kstr == "ecdsa" {
|
||||
// If the map string is "ecdsa", we have to validate the keyBits
|
||||
// are a match for an allowed value, meaning that our curve
|
||||
// is allowed. This isn't necessary when a named curve (e.g.
|
||||
// ssh.KeyAlgoECDSA256) is allowed (and hence kstr is that),
|
||||
// because keyBits is already specified in the kstr. Thus,
|
||||
// we have conditioned around kstr and not keyType (like with
|
||||
// rsa or dsa).
|
||||
if keyBits == value {
|
||||
pass = true
|
||||
}
|
||||
} else {
|
||||
// We get here in two cases: we have a algo-named EC key
|
||||
// matching a format specifier in the key map (e.g., a P-256
|
||||
// key with a KeyAlgoECDSA256 entry in the map) or we have a
|
||||
// ed25519 key (which is always allowed).
|
||||
pass = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !present {
|
||||
return fmt.Errorf("key of type %s is not allowed", keyType)
|
||||
}
|
||||
|
||||
if !pass {
|
||||
return fmt.Errorf("key is of an invalid size: %v", keyBits)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errMsg, ok := r.(string)
|
||||
if ok {
|
||||
retCert = nil
|
||||
retErr = errors.New(errMsg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
serialNumber, err := certutil.GenerateSerialNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to generate signed SSH key: signer is not an AlgorithmSigner")
|
||||
}
|
||||
|
||||
// prepare certificate for signing
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce")
|
||||
}
|
||||
certificate := &ssh.Certificate{
|
||||
Serial: serialNumber.Uint64(),
|
||||
Key: b.PublicKey,
|
||||
KeyId: b.KeyID,
|
||||
ValidPrincipals: b.ValidPrincipals,
|
||||
ValidAfter: uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()),
|
||||
ValidBefore: uint64(now.Add(b.TTL).In(time.UTC).Unix()),
|
||||
CertType: b.CertificateType,
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: b.CriticalOptions,
|
||||
Extensions: b.Extensions,
|
||||
},
|
||||
Nonce: nonce,
|
||||
SignatureKey: sshAlgorithmSigner.PublicKey(),
|
||||
}
|
||||
|
||||
// get bytes to sign; this is based on Certificate.bytesForSigning() from the go ssh lib
|
||||
out := certificate.Marshal()
|
||||
// Drop trailing signature length.
|
||||
certificateBytes := out[:len(out)-4]
|
||||
|
||||
algo := b.Role.AlgorithmSigner
|
||||
|
||||
// Handle the new default algorithm selection process correctly.
|
||||
if algo == DefaultAlgorithmSigner && sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA {
|
||||
algo = ssh.SigAlgoRSASHA2256
|
||||
} else if algo == DefaultAlgorithmSigner {
|
||||
algo = ""
|
||||
}
|
||||
|
||||
sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err)
|
||||
}
|
||||
|
||||
certificate.Signature = sig
|
||||
|
||||
return certificate, nil
|
||||
}
|
||||
|
||||
func createKeyTypeToMapKey(keyType string, keyBits int) map[string][]string {
|
||||
keyTypeToMapKey := map[string][]string{
|
||||
"rsa": {"rsa", ssh.KeyAlgoRSA},
|
||||
"dsa": {"dsa", ssh.KeyAlgoDSA},
|
||||
"ecdsa": {"ecdsa", "ec"},
|
||||
"ed25519": {"ed25519", ssh.KeyAlgoED25519},
|
||||
}
|
||||
|
||||
if keyType == "ecdsa" {
|
||||
if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok {
|
||||
keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo)
|
||||
}
|
||||
}
|
||||
|
||||
return keyTypeToMapKey
|
||||
}
|
|
@ -2,40 +2,12 @@ package ssh
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/dsa"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type creationBundle struct {
|
||||
KeyID string
|
||||
ValidPrincipals []string
|
||||
PublicKey ssh.PublicKey
|
||||
CertificateType uint32
|
||||
TTL time.Duration
|
||||
Signer ssh.Signer
|
||||
Role *sshRole
|
||||
CriticalOptions map[string]string
|
||||
Extensions map[string]string
|
||||
}
|
||||
|
||||
func pathSign(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "sign/" + framework.GenericNameWithAtRegex("role"),
|
||||
|
@ -120,497 +92,10 @@ func (b *backend) pathSignCertificate(ctx context.Context, req *logical.Request,
|
|||
return logical.ErrorResponse(fmt.Sprintf("public_key failed to meet the key requirements: %s", err)), nil
|
||||
}
|
||||
|
||||
// Note that these various functions always return "user errors" so we pass
|
||||
// them as 4xx values
|
||||
keyID, err := b.calculateKeyID(data, req, role, userPublicKey)
|
||||
response, err := b.pathSignIssueCertificateHelper(ctx, req, data, role, userPublicKey)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
certificateType, err := b.calculateCertificateType(data, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
var parsedPrincipals []string
|
||||
if certificateType == ssh.HostCert {
|
||||
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, "", role.AllowedDomains, validateValidPrincipalForHosts(role))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
} else {
|
||||
parsedPrincipals, err = b.calculateValidPrincipals(data, req, role, role.DefaultUser, role.AllowedUsers, strutil.StrListContains)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
}
|
||||
|
||||
ttl, err := b.calculateTTL(data, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
criticalOptions, err := b.calculateCriticalOptions(data, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
extensions, err := b.calculateExtensions(data, req, role)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
privateKeyEntry, err := caKey(ctx, req.Storage, caPrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA private key: %w", err)
|
||||
}
|
||||
if privateKeyEntry == nil || privateKeyEntry.Key == "" {
|
||||
return nil, fmt.Errorf("failed to read CA private key")
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey([]byte(privateKeyEntry.Key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse stored CA private key: %w", err)
|
||||
}
|
||||
|
||||
cBundle := creationBundle{
|
||||
KeyID: keyID,
|
||||
PublicKey: userPublicKey,
|
||||
Signer: signer,
|
||||
ValidPrincipals: parsedPrincipals,
|
||||
TTL: ttl,
|
||||
CertificateType: certificateType,
|
||||
Role: role,
|
||||
CriticalOptions: criticalOptions,
|
||||
Extensions: extensions,
|
||||
}
|
||||
|
||||
certificate, err := cBundle.sign()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signedSSHCertificate := ssh.MarshalAuthorizedKey(certificate)
|
||||
if len(signedSSHCertificate) == 0 {
|
||||
return nil, fmt.Errorf("error marshaling signed certificate")
|
||||
}
|
||||
|
||||
response := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"serial_number": strconv.FormatUint(certificate.Serial, 16),
|
||||
"signed_key": string(signedSSHCertificate),
|
||||
},
|
||||
return response, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateValidPrincipals(data *framework.FieldData, req *logical.Request, role *sshRole, defaultPrincipal, principalsAllowedByRole string, validatePrincipal func([]string, string) bool) ([]string, error) {
|
||||
validPrincipals := ""
|
||||
validPrincipalsRaw, ok := data.GetOk("valid_principals")
|
||||
if ok {
|
||||
validPrincipals = validPrincipalsRaw.(string)
|
||||
} else {
|
||||
validPrincipals = defaultPrincipal
|
||||
}
|
||||
|
||||
parsedPrincipals := strutil.RemoveDuplicates(strutil.ParseStringSlice(validPrincipals, ","), false)
|
||||
// Build list of allowed Principals from template and static principalsAllowedByRole
|
||||
var allowedPrincipals []string
|
||||
for _, principal := range strutil.RemoveDuplicates(strutil.ParseStringSlice(principalsAllowedByRole, ","), false) {
|
||||
if role.AllowedUsersTemplate {
|
||||
// Look for templating markers {{ .* }}
|
||||
matched, _ := regexp.MatchString(`{{.+?}}`, principal)
|
||||
if matched {
|
||||
if req.EntityID != "" {
|
||||
// Retrieve principal based on template + entityID from request.
|
||||
templatePrincipal, err := framework.PopulateIdentityTemplate(principal, req.EntityID, b.System())
|
||||
if err == nil {
|
||||
// Template returned a principal
|
||||
allowedPrincipals = append(allowedPrincipals, templatePrincipal)
|
||||
} else {
|
||||
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", principal, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Static principal or err template
|
||||
allowedPrincipals = append(allowedPrincipals, principal)
|
||||
}
|
||||
} else {
|
||||
// Static principal
|
||||
allowedPrincipals = append(allowedPrincipals, principal)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(parsedPrincipals) == 0:
|
||||
// There is nothing to process
|
||||
return nil, nil
|
||||
case len(allowedPrincipals) == 0:
|
||||
// User has requested principals to be set, but role is not configured
|
||||
// with any principals
|
||||
return nil, fmt.Errorf("role is not configured to allow any principals")
|
||||
default:
|
||||
// Role was explicitly configured to allow any principal.
|
||||
if principalsAllowedByRole == "*" {
|
||||
return parsedPrincipals, nil
|
||||
}
|
||||
|
||||
for _, principal := range parsedPrincipals {
|
||||
if !validatePrincipal(strutil.RemoveDuplicates(allowedPrincipals, false), principal) {
|
||||
return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal)
|
||||
}
|
||||
}
|
||||
return parsedPrincipals, nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
|
||||
return func(allowedPrincipals []string, validPrincipal string) bool {
|
||||
for _, allowedPrincipal := range allowedPrincipals {
|
||||
if allowedPrincipal == validPrincipal && role.AllowBareDomains {
|
||||
return true
|
||||
}
|
||||
if role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+allowedPrincipal) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
|
||||
requestedCertificateType := data.Get("cert_type").(string)
|
||||
|
||||
var certificateType uint32
|
||||
switch requestedCertificateType {
|
||||
case "user":
|
||||
if !role.AllowUserCertificates {
|
||||
return 0, errors.New("cert_type 'user' is not allowed by role")
|
||||
}
|
||||
certificateType = ssh.UserCert
|
||||
case "host":
|
||||
if !role.AllowHostCertificates {
|
||||
return 0, errors.New("cert_type 'host' is not allowed by role")
|
||||
}
|
||||
certificateType = ssh.HostCert
|
||||
default:
|
||||
return 0, errors.New("cert_type must be either 'user' or 'host'")
|
||||
}
|
||||
|
||||
return certificateType, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateKeyID(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) {
|
||||
reqID := data.Get("key_id").(string)
|
||||
|
||||
if reqID != "" {
|
||||
if !role.AllowUserKeyIDs {
|
||||
return "", fmt.Errorf("setting key_id is not allowed by role")
|
||||
}
|
||||
return reqID, nil
|
||||
}
|
||||
|
||||
keyIDFormat := "vault-{{token_display_name}}-{{public_key_hash}}"
|
||||
if req.DisplayName == "" {
|
||||
keyIDFormat = "vault-{{public_key_hash}}"
|
||||
}
|
||||
|
||||
if role.KeyIDFormat != "" {
|
||||
keyIDFormat = role.KeyIDFormat
|
||||
}
|
||||
|
||||
keyID := substQuery(keyIDFormat, map[string]string{
|
||||
"token_display_name": req.DisplayName,
|
||||
"role_name": data.Get("role").(string),
|
||||
"public_key_hash": fmt.Sprintf("%x", sha256.Sum256(pubKey.Marshal())),
|
||||
})
|
||||
|
||||
return keyID, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
|
||||
unparsedCriticalOptions := data.Get("critical_options").(map[string]interface{})
|
||||
if len(unparsedCriticalOptions) == 0 {
|
||||
return role.DefaultCriticalOptions, nil
|
||||
}
|
||||
|
||||
criticalOptions := convertMapToStringValue(unparsedCriticalOptions)
|
||||
|
||||
if role.AllowedCriticalOptions != "" {
|
||||
notAllowedOptions := []string{}
|
||||
allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",")
|
||||
|
||||
for option := range criticalOptions {
|
||||
if !strutil.StrListContains(allowedCriticalOptions, option) {
|
||||
notAllowedOptions = append(notAllowedOptions, option)
|
||||
}
|
||||
}
|
||||
|
||||
if len(notAllowedOptions) != 0 {
|
||||
return nil, fmt.Errorf("critical options not on allowed list: %v", notAllowedOptions)
|
||||
}
|
||||
}
|
||||
|
||||
return criticalOptions, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, error) {
|
||||
unparsedExtensions := data.Get("extensions").(map[string]interface{})
|
||||
extensions := make(map[string]string)
|
||||
|
||||
if len(unparsedExtensions) > 0 {
|
||||
extensions := convertMapToStringValue(unparsedExtensions)
|
||||
if role.AllowedExtensions == "*" {
|
||||
// Allowed extensions was configured to allow all
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
notAllowed := []string{}
|
||||
allowedExtensions := strings.Split(role.AllowedExtensions, ",")
|
||||
for extensionKey := range extensions {
|
||||
if !strutil.StrListContains(allowedExtensions, extensionKey) {
|
||||
notAllowed = append(notAllowed, extensionKey)
|
||||
}
|
||||
}
|
||||
|
||||
if len(notAllowed) != 0 {
|
||||
return nil, fmt.Errorf("extensions %v are not on allowed list", notAllowed)
|
||||
}
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
if role.DefaultExtensionsTemplate {
|
||||
for extensionKey, extensionValue := range role.DefaultExtensions {
|
||||
// Look for templating markers {{ .* }}
|
||||
matched, _ := regexp.MatchString(`^{{.+?}}$`, extensionValue)
|
||||
if matched {
|
||||
if req.EntityID != "" {
|
||||
// Retrieve extension value based on template + entityID from request.
|
||||
templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System())
|
||||
if err == nil {
|
||||
// Template returned an extension value that we can use
|
||||
extensions[extensionKey] = templateExtensionValue
|
||||
} else {
|
||||
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Static extension value or err template
|
||||
extensions[extensionKey] = extensionValue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
extensions = role.DefaultExtensions
|
||||
}
|
||||
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
|
||||
var ttl, maxTTL time.Duration
|
||||
var err error
|
||||
|
||||
ttlRaw, specifiedTTL := data.GetOk("ttl")
|
||||
if specifiedTTL {
|
||||
ttl = time.Duration(ttlRaw.(int)) * time.Second
|
||||
} else {
|
||||
ttl, err = parseutil.ParseDurationSecond(role.TTL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if ttl == 0 {
|
||||
ttl = b.System().DefaultLeaseTTL()
|
||||
}
|
||||
|
||||
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if maxTTL == 0 {
|
||||
maxTTL = b.System().MaxLeaseTTL()
|
||||
}
|
||||
|
||||
if ttl > maxTTL {
|
||||
// Don't error if they were using system defaults, only error if
|
||||
// they specifically chose a bad TTL
|
||||
if !specifiedTTL {
|
||||
ttl = maxTTL
|
||||
} else {
|
||||
return 0, fmt.Errorf("ttl is larger than maximum allowed %d", maxTTL/time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
func (b *backend) validateSignedKeyRequirements(publickey ssh.PublicKey, role *sshRole) error {
|
||||
if len(role.AllowedUserKeyTypesLengths) != 0 {
|
||||
var keyType string
|
||||
var keyBits int
|
||||
|
||||
switch k := publickey.(type) {
|
||||
case ssh.CryptoPublicKey:
|
||||
ff := k.CryptoPublicKey()
|
||||
switch k := ff.(type) {
|
||||
case *rsa.PublicKey:
|
||||
keyType = "rsa"
|
||||
keyBits = k.N.BitLen()
|
||||
case *dsa.PublicKey:
|
||||
keyType = "dsa"
|
||||
keyBits = k.Parameters.P.BitLen()
|
||||
case *ecdsa.PublicKey:
|
||||
keyType = "ecdsa"
|
||||
keyBits = k.Curve.Params().BitSize
|
||||
case ed25519.PublicKey:
|
||||
keyType = "ed25519"
|
||||
default:
|
||||
return fmt.Errorf("public key type of %s is not allowed", keyType)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("pubkey not suitable for crypto (expected ssh.CryptoPublicKey but found %T)", k)
|
||||
}
|
||||
|
||||
keyTypeToMapKey := map[string][]string{
|
||||
"rsa": {"rsa", ssh.KeyAlgoRSA},
|
||||
"dsa": {"dsa", ssh.KeyAlgoDSA},
|
||||
"ecdsa": {"ecdsa", "ec"},
|
||||
"ed25519": {"ed25519", ssh.KeyAlgoED25519},
|
||||
}
|
||||
|
||||
if keyType == "ecdsa" {
|
||||
ecCurveBitsToAlgoName := map[int]string{
|
||||
256: ssh.KeyAlgoECDSA256,
|
||||
384: ssh.KeyAlgoECDSA384,
|
||||
521: ssh.KeyAlgoECDSA521,
|
||||
}
|
||||
|
||||
if algo, ok := ecCurveBitsToAlgoName[keyBits]; ok {
|
||||
keyTypeToMapKey[keyType] = append(keyTypeToMapKey[keyType], algo)
|
||||
}
|
||||
|
||||
// If the algorithm is not found, it could be that we have a curve
|
||||
// that we haven't added a constant for yet. But they could allow it
|
||||
// (assuming x/crypto/ssh can parse it) via setting a ec: <keyBits>
|
||||
// mapping rather than using a named SSH key type, so erring out here
|
||||
// isn't advisable.
|
||||
}
|
||||
|
||||
var present bool
|
||||
var pass bool
|
||||
for _, kstr := range keyTypeToMapKey[keyType] {
|
||||
allowed_values, ok := role.AllowedUserKeyTypesLengths[kstr]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
present = true
|
||||
|
||||
for _, value := range allowed_values {
|
||||
if keyType == "rsa" || keyType == "dsa" {
|
||||
// Regardless of map naming, we always need to validate the
|
||||
// bit length of RSA and DSA keys. Use the keyType flag to
|
||||
if keyBits == value {
|
||||
pass = true
|
||||
}
|
||||
} else if kstr == "ec" || kstr == "ecdsa" {
|
||||
// If the map string is "ecdsa", we have to validate the keyBits
|
||||
// are a match for an allowed value, meaning that our curve
|
||||
// is allowed. This isn't necessary when a named curve (e.g.
|
||||
// ssh.KeyAlgoECDSA256) is allowed (and hence kstr is that),
|
||||
// because keyBits is already specified in the kstr. Thus,
|
||||
// we have conditioned around kstr and not keyType (like with
|
||||
// rsa or dsa).
|
||||
if keyBits == value {
|
||||
pass = true
|
||||
}
|
||||
} else {
|
||||
// We get here in two cases: we have a algo-named EC key
|
||||
// matching a format specifier in the key map (e.g., a P-256
|
||||
// key with a KeyAlgoECDSA256 entry in the map) or we have a
|
||||
// ed25519 key (which is always allowed).
|
||||
pass = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !present {
|
||||
return fmt.Errorf("key of type %s is not allowed", keyType)
|
||||
}
|
||||
|
||||
if !pass {
|
||||
return fmt.Errorf("key is of an invalid size: %v", keyBits)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errMsg, ok := r.(string)
|
||||
if ok {
|
||||
retCert = nil
|
||||
retErr = errors.New(errMsg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
serialNumber, err := certutil.GenerateSerialNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
sshAlgorithmSigner, ok := b.Signer.(ssh.AlgorithmSigner)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to generate signed SSH key: signer is not an AlgorithmSigner")
|
||||
}
|
||||
|
||||
// prepare certificate for signing
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate signed SSH key: error generating random nonce")
|
||||
}
|
||||
certificate := &ssh.Certificate{
|
||||
Serial: serialNumber.Uint64(),
|
||||
Key: b.PublicKey,
|
||||
KeyId: b.KeyID,
|
||||
ValidPrincipals: b.ValidPrincipals,
|
||||
ValidAfter: uint64(now.Add(-b.Role.NotBeforeDuration).In(time.UTC).Unix()),
|
||||
ValidBefore: uint64(now.Add(b.TTL).In(time.UTC).Unix()),
|
||||
CertType: b.CertificateType,
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: b.CriticalOptions,
|
||||
Extensions: b.Extensions,
|
||||
},
|
||||
Nonce: nonce,
|
||||
SignatureKey: sshAlgorithmSigner.PublicKey(),
|
||||
}
|
||||
|
||||
// get bytes to sign; this is based on Certificate.bytesForSigning() from the go ssh lib
|
||||
out := certificate.Marshal()
|
||||
// Drop trailing signature length.
|
||||
certificateBytes := out[:len(out)-4]
|
||||
|
||||
algo := b.Role.AlgorithmSigner
|
||||
|
||||
// Handle the new default algorithm selection process correctly.
|
||||
if algo == DefaultAlgorithmSigner && sshAlgorithmSigner.PublicKey().Type() == ssh.KeyAlgoRSA {
|
||||
algo = ssh.SigAlgoRSASHA2256
|
||||
} else if algo == DefaultAlgorithmSigner {
|
||||
algo = ""
|
||||
}
|
||||
|
||||
sig, err := sshAlgorithmSigner.SignWithAlgorithm(rand.Reader, certificateBytes, algo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate signed SSH key: sign error: %w", err)
|
||||
}
|
||||
|
||||
certificate.Signature = sig
|
||||
|
||||
return certificate, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
ssh: Addition of an endpoint `ssh/issue/:role` to allow the creation of signed key pairs
|
||||
```
|
|
@ -873,3 +873,88 @@ $ curl \
|
|||
"auth": null
|
||||
}
|
||||
```
|
||||
|
||||
## Generate Certificate and Key
|
||||
|
||||
This endpoint generates a new set of credentials (private key and certificate)
|
||||
based on the role named in the endpoint.
|
||||
|
||||
~> **Note**: The private key is _not_ stored. If you do not save the private
|
||||
key from the response, you will need to request a new certificate.
|
||||
|
||||
| Method | Path |
|
||||
| :----- | :---------------- |
|
||||
| `POST` | `/ssh/issue/:name` |
|
||||
|
||||
### Parameters
|
||||
|
||||
- `name` `(string: <required>)` – Specifies the name of the role to create the
|
||||
certificate against. This is part of the request URL.
|
||||
|
||||
- `key_type` `(string: "rsa")` – Specifies the desired key type; must be `rsa`, `ed25519`
|
||||
or `ec`.
|
||||
|
||||
- `key_bits` `(int: 0)` – Specifies the number of bits to use for the
|
||||
generated keys. Allowed values are 0 (universal default); with
|
||||
`key_type=rsa`, allowed values are: 2048 (default), 3072, or
|
||||
4096; with `key_type=ec`, allowed values are: 224, 256 (default),
|
||||
384, or 521; ignored with `key_type=ed25519`.
|
||||
|
||||
- `ttl` `(string: "")` – Specifies the Requested Time To Live. Cannot be greater
|
||||
than the role's `max_ttl` value. If not provided, the role's `ttl` value will
|
||||
be used. Note that the role values default to system values if not explicitly
|
||||
set.
|
||||
|
||||
- `valid_principals` `(string: "")` – Specifies valid principals, either
|
||||
usernames or hostnames, that the certificate should be signed for.
|
||||
|
||||
- `cert_type` `(string: "user")` – Specifies the type of certificate to be
|
||||
created; either "user" or "host".
|
||||
|
||||
- `key_id` `(string: "")` – Specifies the key id that the created certificate
|
||||
should have. If not specified, the display name of the token will be used.
|
||||
|
||||
- `critical_options` `(map<string|string>: "")` – Specifies a map of the
|
||||
critical options that the certificate should be signed for. Defaults to none.
|
||||
|
||||
- `extensions` `(map<string|string>: "")` – Specifies a map of the extensions
|
||||
that the certificate should be signed for. Defaults to none.
|
||||
|
||||
### Sample Payload
|
||||
|
||||
```json
|
||||
{
|
||||
"key_type": "rsa",
|
||||
"key_bits": 2048
|
||||
}
|
||||
```
|
||||
|
||||
### Sample Request
|
||||
|
||||
```shell-session
|
||||
$ curl \
|
||||
--header "X-Vault-Token: ..." \
|
||||
--request POST \
|
||||
--data @payload.json \
|
||||
http://127.0.0.1:8200/v1/ssh/issue/my-role
|
||||
```
|
||||
|
||||
### Sample Response
|
||||
|
||||
```json
|
||||
{
|
||||
"request_id": "94fd1102-08a1-c207-0e3e-657e8f80c09e",
|
||||
"lease_id": "",
|
||||
"renewable": false,
|
||||
"lease_duration": 0,
|
||||
"data": {
|
||||
"serial_number": "1e965817eb12a511",
|
||||
"signed_key": "ssh-rsa-cert-v01@openssh.com AAAAHHN...\n",
|
||||
"private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIEpQIBAAKCAQEAwer03vkQrPV+wWpbisJJv2CKqHmMz+Ej0ctLbhpOmR2CY9S9\n...\nQN351pgTphi6nlCkGPzkDuwvtxSxiCWXQcaxrHAL7MiJpPzkIBq1\n-----END RSA PRIVATE KEY-----\n",
|
||||
"private_key_type": "rsa"
|
||||
},
|
||||
"wrap_info": null,
|
||||
"warnings": null,
|
||||
"auth": null
|
||||
}
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue