Switch cert to tokenutil (#7037)

This commit is contained in:
Jeff Mitchell 2019-07-01 16:31:37 -04:00 committed by GitHub
parent 18a4ab1db5
commit 25f676b42e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 302 additions and 136 deletions

View File

@ -11,6 +11,7 @@ import (
"net/url"
"path/filepath"
"github.com/go-test/deep"
"github.com/hashicorp/go-sockaddr"
"golang.org/x/net/http2"
@ -39,6 +40,7 @@ import (
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/mapstructure"
@ -1949,3 +1951,60 @@ func Test_Renew(t *testing.T) {
t.Fatal("expected error")
}
}
func TestBackend_CertUpgrade(t *testing.T) {
s := &logical.InmemStorage{}
config := logical.TestBackendConfig()
config.StorageView = s
ctx := context.Background()
b := Backend()
if b == nil {
t.Fatalf("failed to create backend")
}
if err := b.Setup(ctx, config); err != nil {
t.Fatal(err)
}
foo := &CertEntry{
Policies: []string{"foo"},
Period: time.Second,
TTL: time.Second,
MaxTTL: time.Second,
BoundCIDRs: []*sockaddr.SockAddrMarshaler{&sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("127.0.0.1")}},
}
entry, err := logical.StorageEntryJSON("cert/foo", foo)
if err != nil {
t.Fatal(err)
}
err = s.Put(ctx, entry)
if err != nil {
t.Fatal(err)
}
certEntry, err := b.Cert(ctx, s, "foo")
if err != nil {
t.Fatal(err)
}
exp := &CertEntry{
Policies: []string{"foo"},
Period: time.Second,
TTL: time.Second,
MaxTTL: time.Second,
BoundCIDRs: []*sockaddr.SockAddrMarshaler{&sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("127.0.0.1")}},
TokenParams: tokenutil.TokenParams{
TokenPolicies: []string{"foo"},
TokenPeriod: time.Second,
TokenTTL: time.Second,
TokenMaxTTL: time.Second,
TokenBoundCIDRs: []*sockaddr.SockAddrMarshaler{&sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("127.0.0.1")}},
},
}
if diff := deep.Equal(certEntry, exp); diff != nil {
t.Fatal(diff)
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
)
@ -28,7 +29,7 @@ func pathListCerts(b *backend) *framework.Path {
}
func pathCerts(b *backend) *framework.Path {
return &framework.Path{
p := &framework.Path{
Pattern: "certs/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
@ -95,39 +96,38 @@ certificate.`,
"policies": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "Comma-separated list of policies.",
Description: tokenutil.DeprecationText("token_policies"),
Deprecated: true,
},
"lease": &framework.FieldSchema{
Type: framework.TypeInt,
Description: `Deprecated: use "ttl" instead. TTL time in
seconds. Defaults to system/backend default TTL.`,
Type: framework.TypeInt,
Description: tokenutil.DeprecationText("token_ttl"),
Deprecated: true,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `TTL for tokens issued by this backend.
Defaults to system/backend default TTL time.`,
Type: framework.TypeDurationSecond,
Description: tokenutil.DeprecationText("token_ttl"),
Deprecated: true,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Duration in either an integer number of seconds (3600) or
an integer time unit (60m) after which the
issued token can no longer be renewed.`,
Type: framework.TypeDurationSecond,
Description: tokenutil.DeprecationText("token_max_ttl"),
Deprecated: true,
},
"period": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `If set, indicates that the token generated using this role
should never expire. The token should be renewed within the
duration specified by this value. At each renewal, the token's
TTL will be set to the value of this parameter.`,
Type: framework.TypeDurationSecond,
Description: tokenutil.DeprecationText("token_period"),
Deprecated: true,
},
"bound_cidrs": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `Comma separated string or list of CIDR blocks. If set, specifies the blocks of
IP addresses which can perform the login operation.`,
Type: framework.TypeCommaStringSlice,
Description: tokenutil.DeprecationText("token_bound_cidrs"),
Deprecated: true,
},
},
@ -140,6 +140,9 @@ IP addresses which can perform the login operation.`,
HelpSynopsis: pathCertHelpSyn,
HelpDescription: pathCertHelpDesc,
}
tokenutil.AddTokenFields(p.Fields)
return p
}
func (b *backend) Cert(ctx context.Context, s logical.Storage, n string) (*CertEntry, error) {
@ -155,6 +158,23 @@ func (b *backend) Cert(ctx context.Context, s logical.Storage, n string) (*CertE
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
if result.TokenTTL == 0 && result.TTL > 0 {
result.TokenTTL = result.TTL
}
if result.TokenMaxTTL == 0 && result.MaxTTL > 0 {
result.TokenMaxTTL = result.MaxTTL
}
if result.TokenPeriod == 0 && result.Period > 0 {
result.TokenPeriod = result.Period
}
if len(result.TokenPolicies) == 0 && len(result.Policies) > 0 {
result.TokenPolicies = result.Policies
}
if len(result.TokenBoundCIDRs) == 0 && len(result.BoundCIDRs) > 0 {
result.TokenBoundCIDRs = result.BoundCIDRs
}
return &result, nil
}
@ -183,86 +203,202 @@ func (b *backend) pathCertRead(ctx context.Context, req *logical.Request, d *fra
return nil, nil
}
data := map[string]interface{}{
"certificate": cert.Certificate,
"display_name": cert.DisplayName,
"allowed_names": cert.AllowedNames,
"allowed_common_names": cert.AllowedCommonNames,
"allowed_dns_sans": cert.AllowedDNSSANs,
"allowed_email_sans": cert.AllowedEmailSANs,
"allowed_uri_sans": cert.AllowedURISANs,
"allowed_organizational_units": cert.AllowedOrganizationalUnits,
"required_extensions": cert.RequiredExtensions,
}
cert.PopulateTokenData(data)
if cert.TTL > 0 {
data["ttl"] = int64(cert.TTL.Seconds())
}
if cert.MaxTTL > 0 {
data["max_ttl"] = int64(cert.MaxTTL.Seconds())
}
if cert.Period > 0 {
data["period"] = int64(cert.Period.Seconds())
}
if len(cert.Policies) > 0 {
data["policies"] = data["token_policies"]
}
if len(cert.BoundCIDRs) > 0 {
data["bound_cidrs"] = data["token_bound_cidrs"]
}
return &logical.Response{
Data: map[string]interface{}{
"certificate": cert.Certificate,
"display_name": cert.DisplayName,
"policies": cert.Policies,
"ttl": cert.TTL / time.Second,
"max_ttl": cert.MaxTTL / time.Second,
"period": cert.Period / time.Second,
"allowed_names": cert.AllowedNames,
"allowed_common_names": cert.AllowedCommonNames,
"allowed_dns_sans": cert.AllowedDNSSANs,
"allowed_email_sans": cert.AllowedEmailSANs,
"allowed_uri_sans": cert.AllowedURISANs,
"allowed_organizational_units": cert.AllowedOrganizationalUnits,
"required_extensions": cert.RequiredExtensions,
"bound_cidrs": cert.BoundCIDRs,
},
Data: data,
}, nil
}
func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := strings.ToLower(d.Get("name").(string))
certificate := d.Get("certificate").(string)
displayName := d.Get("display_name").(string)
policies := policyutil.ParsePolicies(d.Get("policies"))
allowedNames := d.Get("allowed_names").([]string)
allowedCommonNames := d.Get("allowed_common_names").([]string)
allowedDNSSANs := d.Get("allowed_dns_sans").([]string)
allowedEmailSANs := d.Get("allowed_email_sans").([]string)
allowedURISANs := d.Get("allowed_uri_sans").([]string)
allowedOrganizationalUnits := d.Get("allowed_organizational_units").([]string)
requiredExtensions := d.Get("required_extensions").([]string)
cert, err := b.Cert(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if cert == nil {
cert = &CertEntry{
Name: name,
}
}
// Get non tokenutil fields
if certificateRaw, ok := d.GetOk("certificate"); ok {
cert.Certificate = certificateRaw.(string)
}
if displayNameRaw, ok := d.GetOk("display_name"); ok {
cert.DisplayName = displayNameRaw.(string)
}
if allowedNamesRaw, ok := d.GetOk("allowed_names"); ok {
cert.AllowedNames = allowedNamesRaw.([]string)
}
if allowedCommonNamesRaw, ok := d.GetOk("allowed_common_names"); ok {
cert.AllowedCommonNames = allowedCommonNamesRaw.([]string)
}
if allowedDNSSANsRaw, ok := d.GetOk("allowed_dns_sans"); ok {
cert.AllowedDNSSANs = allowedDNSSANsRaw.([]string)
}
if allowedEmailSANsRaw, ok := d.GetOk("allowed_email_sans"); ok {
cert.AllowedEmailSANs = allowedEmailSANsRaw.([]string)
}
if allowedURISANsRaw, ok := d.GetOk("allowed_uri_sans"); ok {
cert.AllowedURISANs = allowedURISANsRaw.([]string)
}
if allowedOrganizationalUnitsRaw, ok := d.GetOk("allowed_organizational_units"); ok {
cert.AllowedOrganizationalUnits = allowedOrganizationalUnitsRaw.([]string)
}
if requiredExtensionsRaw, ok := d.GetOk("required_extensions"); ok {
cert.RequiredExtensions = requiredExtensionsRaw.([]string)
}
// Get tokenutil fields
if err := cert.ParseTokenFields(req, d); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
// Handle upgrade cases
{
policiesRaw, ok := d.GetOk("token_policies")
if !ok {
policiesRaw, ok = d.GetOk("policies")
if ok {
cert.Policies = policyutil.ParsePolicies(policiesRaw)
cert.TokenPolicies = cert.Policies
}
} else {
_, ok = d.GetOk("policies")
if ok {
cert.Policies = cert.TokenPolicies
} else {
cert.Policies = nil
}
}
ttlRaw, ok := d.GetOk("token_ttl")
if !ok {
ttlRaw, ok = d.GetOk("ttl")
if !ok {
ttlRaw, ok = d.GetOk("lease")
}
if ok {
cert.TTL = time.Duration(ttlRaw.(int)) * time.Second
cert.TokenTTL = cert.TTL
}
} else {
_, ok = d.GetOk("ttl")
if ok {
cert.TTL = cert.TokenTTL
} else {
cert.TTL = 0
}
}
maxTTLRaw, ok := d.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = d.GetOk("max_ttl")
if ok {
cert.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
cert.TokenMaxTTL = cert.MaxTTL
}
} else {
_, ok = d.GetOk("max_ttl")
if ok {
cert.MaxTTL = cert.TokenMaxTTL
} else {
cert.MaxTTL = 0
}
}
periodRaw, ok := d.GetOk("token_period")
if !ok {
periodRaw, ok = d.GetOk("period")
if ok {
cert.Period = time.Duration(periodRaw.(int)) * time.Second
cert.TokenPeriod = cert.Period
}
} else {
_, ok = d.GetOk("period")
if ok {
cert.Period = cert.TokenPeriod
} else {
cert.Period = 0
}
}
boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs")
if !ok {
boundCIDRsRaw, ok = d.GetOk("bound_cidrs")
if ok {
boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
cert.BoundCIDRs = boundCIDRs
cert.TokenBoundCIDRs = cert.BoundCIDRs
}
} else {
_, ok = d.GetOk("bound_cidrs")
if ok {
cert.BoundCIDRs = cert.TokenBoundCIDRs
} else {
cert.BoundCIDRs = nil
}
}
}
var resp logical.Response
// Parse the ttl (or lease duration)
systemDefaultTTL := b.System().DefaultLeaseTTL()
ttl := time.Duration(d.Get("ttl").(int)) * time.Second
if ttl == 0 {
ttl = time.Duration(d.Get("lease").(int)) * time.Second
if cert.TokenTTL > systemDefaultTTL {
resp.AddWarning(fmt.Sprintf("Given ttl of %d seconds is greater than current mount/system default of %d seconds", cert.TokenTTL/time.Second, systemDefaultTTL/time.Second))
}
if ttl > systemDefaultTTL {
resp.AddWarning(fmt.Sprintf("Given ttl of %d seconds is greater than current mount/system default of %d seconds", ttl/time.Second, systemDefaultTTL/time.Second))
}
if ttl < time.Duration(0) {
return logical.ErrorResponse("ttl cannot be negative"), nil
}
// Parse max_ttl
systemMaxTTL := b.System().MaxLeaseTTL()
maxTTL := time.Duration(d.Get("max_ttl").(int)) * time.Second
if maxTTL > systemMaxTTL {
resp.AddWarning(fmt.Sprintf("Given max_ttl of %d seconds is greater than current mount/system default of %d seconds", maxTTL/time.Second, systemMaxTTL/time.Second))
if cert.TokenMaxTTL > systemMaxTTL {
resp.AddWarning(fmt.Sprintf("Given max_ttl of %d seconds is greater than current mount/system default of %d seconds", cert.TokenMaxTTL/time.Second, systemMaxTTL/time.Second))
}
if maxTTL < time.Duration(0) {
return logical.ErrorResponse("max_ttl cannot be negative"), nil
}
if maxTTL != 0 && ttl > maxTTL {
if cert.TokenMaxTTL != 0 && cert.TokenTTL > cert.TokenMaxTTL {
return logical.ErrorResponse("ttl should be shorter than max_ttl"), nil
}
// Parse period
period := time.Duration(d.Get("period").(int)) * time.Second
if period > systemMaxTTL {
resp.AddWarning(fmt.Sprintf("Given period of %d seconds is greater than the backend's maximum TTL of %d seconds", period/time.Second, systemMaxTTL/time.Second))
}
if period < time.Duration(0) {
return logical.ErrorResponse("period cannot be negative"), nil
if cert.TokenPeriod > systemMaxTTL {
resp.AddWarning(fmt.Sprintf("Given period of %d seconds is greater than the backend's maximum TTL of %d seconds", cert.TokenPeriod/time.Second, systemMaxTTL/time.Second))
}
// Default the display name to the certificate name if not given
if displayName == "" {
displayName = name
if cert.DisplayName == "" {
cert.DisplayName = name
}
parsed := parsePEM([]byte(certificate))
parsed := parsePEM([]byte(cert.Certificate))
if len(parsed) == 0 {
return logical.ErrorResponse("failed to parse certificate"), nil
}
@ -281,31 +417,8 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
}
}
parsedCIDRs, err := parseutil.ParseAddrs(d.Get("bound_cidrs"))
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
certEntry := &CertEntry{
Name: name,
Certificate: certificate,
DisplayName: displayName,
Policies: policies,
AllowedNames: allowedNames,
AllowedCommonNames: allowedCommonNames,
AllowedDNSSANs: allowedDNSSANs,
AllowedEmailSANs: allowedEmailSANs,
AllowedURISANs: allowedURISANs,
AllowedOrganizationalUnits: allowedOrganizationalUnits,
RequiredExtensions: requiredExtensions,
TTL: ttl,
MaxTTL: maxTTL,
Period: period,
BoundCIDRs: parsedCIDRs,
}
// Store it
entry, err := logical.StorageEntryJSON("cert/"+name, certEntry)
entry, err := logical.StorageEntryJSON("cert/"+name, cert)
if err != nil {
return nil, err
}
@ -321,6 +434,8 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
}
type CertEntry struct {
tokenutil.TokenParams
Name string
Certificate string
DisplayName string

View File

@ -83,36 +83,28 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
resp := &logical.Response{
Auth: &logical.Auth{
Period: matched.Entry.Period,
InternalData: map[string]interface{}{
"subject_key_id": skid,
"authority_key_id": akid,
},
Policies: matched.Entry.Policies,
DisplayName: matched.Entry.DisplayName,
Metadata: map[string]string{
"cert_name": matched.Entry.Name,
"common_name": clientCerts[0].Subject.CommonName,
"serial_number": clientCerts[0].SerialNumber.String(),
"subject_key_id": certutil.GetHexFormatted(clientCerts[0].SubjectKeyId, ":"),
"authority_key_id": certutil.GetHexFormatted(clientCerts[0].AuthorityKeyId, ":"),
},
LeaseOptions: logical.LeaseOptions{
Renewable: true,
TTL: matched.Entry.TTL,
MaxTTL: matched.Entry.MaxTTL,
},
Alias: &logical.Alias{
Name: clientCerts[0].Subject.CommonName,
},
BoundCIDRs: matched.Entry.BoundCIDRs,
auth := &logical.Auth{
InternalData: map[string]interface{}{
"subject_key_id": skid,
"authority_key_id": akid,
},
DisplayName: matched.Entry.DisplayName,
Metadata: map[string]string{
"cert_name": matched.Entry.Name,
"common_name": clientCerts[0].Subject.CommonName,
"serial_number": clientCerts[0].SerialNumber.String(),
"subject_key_id": certutil.GetHexFormatted(clientCerts[0].SubjectKeyId, ":"),
"authority_key_id": certutil.GetHexFormatted(clientCerts[0].AuthorityKeyId, ":"),
},
Alias: &logical.Alias{
Name: clientCerts[0].Subject.CommonName,
},
}
matched.Entry.PopulateTokenAuth(auth)
// Generate a response
return resp, nil
return &logical.Response{
Auth: auth,
}, nil
}
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
@ -159,14 +151,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return nil, nil
}
if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.TokenPolicies) {
if !policyutil.EquivalentPolicies(cert.TokenPolicies, req.Auth.TokenPolicies) {
return nil, fmt.Errorf("policies have changed, not renewing")
}
resp := &logical.Response{Auth: req.Auth}
resp.Auth.TTL = cert.TTL
resp.Auth.MaxTTL = cert.MaxTTL
resp.Auth.Period = cert.Period
resp.Auth.TTL = cert.TokenTTL
resp.Auth.MaxTTL = cert.TokenMaxTTL
resp.Auth.Period = cert.TokenPeriod
return resp, nil
}
@ -478,7 +470,7 @@ func (b *backend) checkForValidChain(chains [][]*x509.Certificate) bool {
}
func (b *backend) checkCIDR(cert *CertEntry, req *logical.Request) error {
if cidrutil.RemoteAddrIsOk(req.Connection.RemoteAddr, cert.BoundCIDRs) {
if cidrutil.RemoteAddrIsOk(req.Connection.RemoteAddr, cert.TokenBoundCIDRs) {
return nil
}
return logical.ErrPermissionDenied