Rework and refactoring
This commit is contained in:
parent
3aeae62c00
commit
5996c3e9d8
|
@ -91,7 +91,7 @@ func (b *backend) pathConfigCertificateExistenceCheck(req *logical.Request, data
|
|||
if certName == "" {
|
||||
return false, fmt.Errorf("missing cert_name")
|
||||
}
|
||||
entry, err := awsPublicCertificateEntry(req.Storage, certName)
|
||||
entry, err := b.awsPublicCertificateEntry(req.Storage, certName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ func decodePEMAndParseCertificate(certificate string) (*x509.Certificate, error)
|
|||
// awsPublicCertificates returns a slice of all the parsed AWS public
|
||||
// certificates, that were registered using `config/certificate/<cert_name>` endpoint.
|
||||
// This method will also append two default certificates to the slice.
|
||||
func awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
|
||||
func (b *backend) awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
|
||||
|
||||
// Get the list `cert_name`s of all the registered certificates.
|
||||
registeredCerts, err := s.List("config/certificate/")
|
||||
|
@ -144,7 +144,7 @@ func awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
|
|||
|
||||
// Iterate through each certificate, parse and append it to a slice.
|
||||
for _, cert := range registeredCerts {
|
||||
certEntry, err := awsPublicCertificateEntry(s, cert)
|
||||
certEntry, err := b.awsPublicCertificateEntry(s, cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ func awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
|
|||
|
||||
// awsPublicCertificate is used to get the configured AWS Public Key that is used
|
||||
// to verify the PKCS#7 signature of the instance identity document.
|
||||
func awsPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
func (b *backend) awsPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
entry, err := s.Get("config/certificate/" + certName)
|
||||
|
@ -213,7 +213,7 @@ func (b *backend) pathConfigCertificateRead(
|
|||
return logical.ErrorResponse("missing cert_name"), nil
|
||||
}
|
||||
|
||||
certificateEntry, err := awsPublicCertificateEntry(req.Storage, certName)
|
||||
certificateEntry, err := b.awsPublicCertificateEntry(req.Storage, certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(
|
|||
}
|
||||
|
||||
// Check if there is already a certificate entry registered.
|
||||
certEntry, err := awsPublicCertificateEntry(req.Storage, certName)
|
||||
certEntry, err := b.awsPublicCertificateEntry(req.Storage, certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package aws
|
|||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
@ -115,8 +116,9 @@ func (b *backend) pathImageTagUpdate(
|
|||
return logical.ErrorResponse("max_ttl cannot be negative"), nil
|
||||
}
|
||||
|
||||
// Attach version, nonce, policies and maxTTL to the role tag value.
|
||||
rTagValue, err := prepareRoleTagPlainValue(&roleTag{Version: roleTagVersion,
|
||||
// Create a role tag out of all the information provided.
|
||||
rTagValue, err := createRoleTagValue(req.Storage, &roleTag{
|
||||
Version: roleTagVersion,
|
||||
AmiID: amiID,
|
||||
Nonce: nonce,
|
||||
Policies: policies,
|
||||
|
@ -128,24 +130,8 @@ func (b *backend) pathImageTagUpdate(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Get the key used for creating the HMAC
|
||||
key, err := hmacKey(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the HMAC of the value
|
||||
hmacB64, err := createRoleTagHMACBase64(key, rTagValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// attach the HMAC to the value
|
||||
rTagValue = fmt.Sprintf("%s:%s", rTagValue, hmacB64)
|
||||
if len(rTagValue) > 255 {
|
||||
return nil, fmt.Errorf("role tag 'value' exceeding the limit of 255 characters")
|
||||
}
|
||||
|
||||
// Return the key to be used for the tag and the value to be used for that tag key.
|
||||
// This key value pair should be set on the EC2 instance.
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"tag_key": imageEntry.RoleTag,
|
||||
|
@ -154,12 +140,48 @@ func (b *backend) pathImageTagUpdate(
|
|||
}, nil
|
||||
}
|
||||
|
||||
// createRoleTagValue prepares the plaintext version of the role tag,
|
||||
// and appends a HMAC of the plaintext value to it, before returning.
|
||||
func createRoleTagValue(s logical.Storage, rTag *roleTag) (string, error) {
|
||||
// Attach version, nonce, policies and maxTTL to the role tag value.
|
||||
rTagPlainText, err := prepareRoleTagPlaintextValue(rTag)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return appendHMAC(s, rTagPlainText)
|
||||
}
|
||||
|
||||
// Takes in the plaintext part of the role tag, creates a HMAC of it and returns
|
||||
// a role tag value containing both the plaintext part and the HMAC part.
|
||||
func appendHMAC(s logical.Storage, rTagPlainText string) (string, error) {
|
||||
// Get the key used for creating the HMAC
|
||||
key, err := hmacKey(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the HMAC of the value
|
||||
hmacB64, err := createRoleTagHMACBase64(key, rTagPlainText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// attach the HMAC to the value
|
||||
rTagValue := fmt.Sprintf("%s:%s", rTagPlainText, hmacB64)
|
||||
if len(rTagValue) > 255 {
|
||||
return "", fmt.Errorf("role tag 'value' exceeding the limit of 255 characters")
|
||||
}
|
||||
|
||||
return rTagValue, nil
|
||||
}
|
||||
|
||||
// verifyRoleTagValue rebuilds the role tag value without the HMAC,
|
||||
// computes the HMAC from it using the backend specific key and
|
||||
// compares it with the received HMAC.
|
||||
func verifyRoleTagValue(s logical.Storage, rTag *roleTag) (bool, error) {
|
||||
// Fetch the plaintext part of role tag
|
||||
rTagPlainText, err := prepareRoleTagPlainValue(rTag)
|
||||
rTagPlainText, err := prepareRoleTagPlaintextValue(rTag)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -175,41 +197,30 @@ func verifyRoleTagValue(s logical.Storage, rTag *roleTag) (bool, error) {
|
|||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return rTag.HMAC == hmacB64, nil
|
||||
return subtle.ConstantTimeCompare([]byte(rTag.HMAC), []byte(hmacB64)) == 1, nil
|
||||
}
|
||||
|
||||
// prepareRoleTagPlainValue builds the role tag value without the HMAC in it.
|
||||
func prepareRoleTagPlainValue(rTag *roleTag) (string, error) {
|
||||
// prepareRoleTagPlaintextValue builds the role tag value without the HMAC in it.
|
||||
func prepareRoleTagPlaintextValue(rTag *roleTag) (string, error) {
|
||||
if rTag.Version == "" {
|
||||
return "", fmt.Errorf("missing version")
|
||||
}
|
||||
// attach version to the value
|
||||
value := rTag.Version
|
||||
|
||||
if rTag.Nonce == "" {
|
||||
return "", fmt.Errorf("missing nonce")
|
||||
}
|
||||
// attach nonce to the value
|
||||
value = fmt.Sprintf("%s:%s", value, rTag.Nonce)
|
||||
|
||||
if rTag.AmiID == "" {
|
||||
return "", fmt.Errorf("missing ami_id")
|
||||
}
|
||||
// attach ami_id to the value
|
||||
value = fmt.Sprintf("%s:a=%s", value, rTag.AmiID)
|
||||
|
||||
// attach policies to value. rTag.Policies will never be empty.
|
||||
value = fmt.Sprintf("%s:p=%s", value, strings.Join(rTag.Policies, ","))
|
||||
// Attach Version, Nonce, AMI ID, Policies, DisallowReauthentication fields.
|
||||
value := fmt.Sprintf("%s:%s:a=%s:p=%s:d=%s", rTag.Version, rTag.Nonce, rTag.AmiID, strings.Join(rTag.Policies, ","), strconv.FormatBool(rTag.DisallowReauthentication))
|
||||
|
||||
// attach disallow_reauthentication field
|
||||
value = fmt.Sprintf("%s:d=%s", value, strconv.FormatBool(rTag.DisallowReauthentication))
|
||||
|
||||
// attach instance_id if set
|
||||
// Attach instance_id if set.
|
||||
if rTag.InstanceID != "" {
|
||||
value = fmt.Sprintf("%s:i=%s", value, rTag.InstanceID)
|
||||
}
|
||||
|
||||
// attach max_ttl if it is provided
|
||||
// Attach max_ttl if it is provided.
|
||||
if rTag.MaxTTL > time.Duration(0) {
|
||||
value = fmt.Sprintf("%s:t=%s", value, rTag.MaxTTL)
|
||||
}
|
||||
|
|
|
@ -8,10 +8,10 @@ import (
|
|||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/ec2"
|
||||
"github.com/fullsailor/pkcs7"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/fullsailor/pkcs7"
|
||||
)
|
||||
|
||||
func pathLogin(b *backend) *framework.Path {
|
||||
|
@ -116,7 +116,7 @@ func validateMetadata(clientNonce, pendingTime string, storedIdentity *whitelist
|
|||
// Verifies the correctness of the authenticated attributes present in the PKCS#7
|
||||
// signature. After verification, extracts the instance identity document from the
|
||||
// signature, parses it and returns it.
|
||||
func parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
|
||||
func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
|
||||
pkcs7B64 = fmt.Sprintf("-----BEGIN PKCS7-----\n%s\n-----END PKCS7-----", pkcs7B64)
|
||||
|
||||
// Decode the PEM encoded signature.
|
||||
|
@ -132,7 +132,7 @@ func parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocumen
|
|||
}
|
||||
|
||||
// Get the public certificate that is used to verify the signature.
|
||||
publicCerts, err := awsPublicCertificates(s)
|
||||
publicCerts, err := b.awsPublicCertificates(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ func (b *backend) pathLoginUpdate(
|
|||
}
|
||||
|
||||
// Verify the signature of the identity document.
|
||||
identityDoc, err := parseIdentityDocument(req.Storage, pkcs7B64)
|
||||
identityDoc, err := b.parseIdentityDocument(req.Storage, pkcs7B64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue