Rework and refactoring

This commit is contained in:
vishalnayak 2016-04-19 17:07:06 -04:00
parent 3aeae62c00
commit 5996c3e9d8
3 changed files with 60 additions and 49 deletions

View file

@ -91,7 +91,7 @@ func (b *backend) pathConfigCertificateExistenceCheck(req *logical.Request, data
if certName == "" {
return false, fmt.Errorf("missing cert_name")
}
entry, err := awsPublicCertificateEntry(req.Storage, certName)
entry, err := b.awsPublicCertificateEntry(req.Storage, certName)
if err != nil {
return false, err
}
@ -132,7 +132,7 @@ func decodePEMAndParseCertificate(certificate string) (*x509.Certificate, error)
// awsPublicCertificates returns a slice of all the parsed AWS public
// certificates, that were registered using `config/certificate/<cert_name>` endpoint.
// This method will also append two default certificates to the slice.
func awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
func (b *backend) awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
// Get the list `cert_name`s of all the registered certificates.
registeredCerts, err := s.List("config/certificate/")
@ -144,7 +144,7 @@ func awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
// Iterate through each certificate, parse and append it to a slice.
for _, cert := range registeredCerts {
certEntry, err := awsPublicCertificateEntry(s, cert)
certEntry, err := b.awsPublicCertificateEntry(s, cert)
if err != nil {
return nil, err
}
@ -170,7 +170,7 @@ func awsPublicCertificates(s logical.Storage) ([]*x509.Certificate, error) {
// awsPublicCertificate is used to get the configured AWS Public Key that is used
// to verify the PKCS#7 signature of the instance identity document.
func awsPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
func (b *backend) awsPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
entry, err := s.Get("config/certificate/" + certName)
@ -213,7 +213,7 @@ func (b *backend) pathConfigCertificateRead(
return logical.ErrorResponse("missing cert_name"), nil
}
certificateEntry, err := awsPublicCertificateEntry(req.Storage, certName)
certificateEntry, err := b.awsPublicCertificateEntry(req.Storage, certName)
if err != nil {
return nil, err
}
@ -239,7 +239,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(
}
// Check if there is already a certificate entry registered.
certEntry, err := awsPublicCertificateEntry(req.Storage, certName)
certEntry, err := b.awsPublicCertificateEntry(req.Storage, certName)
if err != nil {
return nil, err
}

View file

@ -3,6 +3,7 @@ package aws
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"strconv"
@ -115,8 +116,9 @@ func (b *backend) pathImageTagUpdate(
return logical.ErrorResponse("max_ttl cannot be negative"), nil
}
// Attach version, nonce, policies and maxTTL to the role tag value.
rTagValue, err := prepareRoleTagPlainValue(&roleTag{Version: roleTagVersion,
// Create a role tag out of all the information provided.
rTagValue, err := createRoleTagValue(req.Storage, &roleTag{
Version: roleTagVersion,
AmiID: amiID,
Nonce: nonce,
Policies: policies,
@ -128,24 +130,8 @@ func (b *backend) pathImageTagUpdate(
return nil, err
}
// Get the key used for creating the HMAC
key, err := hmacKey(req.Storage)
if err != nil {
return nil, err
}
// Create the HMAC of the value
hmacB64, err := createRoleTagHMACBase64(key, rTagValue)
if err != nil {
return nil, err
}
// attach the HMAC to the value
rTagValue = fmt.Sprintf("%s:%s", rTagValue, hmacB64)
if len(rTagValue) > 255 {
return nil, fmt.Errorf("role tag 'value' exceeding the limit of 255 characters")
}
// Return the key to be used for the tag and the value to be used for that tag key.
// This key value pair should be set on the EC2 instance.
return &logical.Response{
Data: map[string]interface{}{
"tag_key": imageEntry.RoleTag,
@ -154,12 +140,48 @@ func (b *backend) pathImageTagUpdate(
}, nil
}
// createRoleTagValue prepares the plaintext version of the role tag,
// and appends a HMAC of the plaintext value to it, before returning.
func createRoleTagValue(s logical.Storage, rTag *roleTag) (string, error) {
// Attach version, nonce, policies and maxTTL to the role tag value.
rTagPlainText, err := prepareRoleTagPlaintextValue(rTag)
if err != nil {
return "", err
}
return appendHMAC(s, rTagPlainText)
}
// Takes in the plaintext part of the role tag, creates a HMAC of it and returns
// a role tag value containing both the plaintext part and the HMAC part.
func appendHMAC(s logical.Storage, rTagPlainText string) (string, error) {
// Get the key used for creating the HMAC
key, err := hmacKey(s)
if err != nil {
return "", err
}
// Create the HMAC of the value
hmacB64, err := createRoleTagHMACBase64(key, rTagPlainText)
if err != nil {
return "", err
}
// attach the HMAC to the value
rTagValue := fmt.Sprintf("%s:%s", rTagPlainText, hmacB64)
if len(rTagValue) > 255 {
return "", fmt.Errorf("role tag 'value' exceeding the limit of 255 characters")
}
return rTagValue, nil
}
// verifyRoleTagValue rebuilds the role tag value without the HMAC,
// computes the HMAC from it using the backend specific key and
// compares it with the received HMAC.
func verifyRoleTagValue(s logical.Storage, rTag *roleTag) (bool, error) {
// Fetch the plaintext part of role tag
rTagPlainText, err := prepareRoleTagPlainValue(rTag)
rTagPlainText, err := prepareRoleTagPlaintextValue(rTag)
if err != nil {
return false, err
}
@ -175,41 +197,30 @@ func verifyRoleTagValue(s logical.Storage, rTag *roleTag) (bool, error) {
if err != nil {
return false, err
}
return rTag.HMAC == hmacB64, nil
return subtle.ConstantTimeCompare([]byte(rTag.HMAC), []byte(hmacB64)) == 1, nil
}
// prepareRoleTagPlainValue builds the role tag value without the HMAC in it.
func prepareRoleTagPlainValue(rTag *roleTag) (string, error) {
// prepareRoleTagPlaintextValue builds the role tag value without the HMAC in it.
func prepareRoleTagPlaintextValue(rTag *roleTag) (string, error) {
if rTag.Version == "" {
return "", fmt.Errorf("missing version")
}
// attach version to the value
value := rTag.Version
if rTag.Nonce == "" {
return "", fmt.Errorf("missing nonce")
}
// attach nonce to the value
value = fmt.Sprintf("%s:%s", value, rTag.Nonce)
if rTag.AmiID == "" {
return "", fmt.Errorf("missing ami_id")
}
// attach ami_id to the value
value = fmt.Sprintf("%s:a=%s", value, rTag.AmiID)
// attach policies to value. rTag.Policies will never be empty.
value = fmt.Sprintf("%s:p=%s", value, strings.Join(rTag.Policies, ","))
// Attach Version, Nonce, AMI ID, Policies, DisallowReauthentication fields.
value := fmt.Sprintf("%s:%s:a=%s:p=%s:d=%s", rTag.Version, rTag.Nonce, rTag.AmiID, strings.Join(rTag.Policies, ","), strconv.FormatBool(rTag.DisallowReauthentication))
// attach disallow_reauthentication field
value = fmt.Sprintf("%s:d=%s", value, strconv.FormatBool(rTag.DisallowReauthentication))
// attach instance_id if set
// Attach instance_id if set.
if rTag.InstanceID != "" {
value = fmt.Sprintf("%s:i=%s", value, rTag.InstanceID)
}
// attach max_ttl if it is provided
// Attach max_ttl if it is provided.
if rTag.MaxTTL > time.Duration(0) {
value = fmt.Sprintf("%s:t=%s", value, rTag.MaxTTL)
}

View file

@ -8,10 +8,10 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/fullsailor/pkcs7"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/fullsailor/pkcs7"
)
func pathLogin(b *backend) *framework.Path {
@ -116,7 +116,7 @@ func validateMetadata(clientNonce, pendingTime string, storedIdentity *whitelist
// Verifies the correctness of the authenticated attributes present in the PKCS#7
// signature. After verification, extracts the instance identity document from the
// signature, parses it and returns it.
func parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
pkcs7B64 = fmt.Sprintf("-----BEGIN PKCS7-----\n%s\n-----END PKCS7-----", pkcs7B64)
// Decode the PEM encoded signature.
@ -132,7 +132,7 @@ func parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocumen
}
// Get the public certificate that is used to verify the signature.
publicCerts, err := awsPublicCertificates(s)
publicCerts, err := b.awsPublicCertificates(s)
if err != nil {
return nil, err
}
@ -177,7 +177,7 @@ func (b *backend) pathLoginUpdate(
}
// Verify the signature of the identity document.
identityDoc, err := parseIdentityDocument(req.Storage, pkcs7B64)
identityDoc, err := b.parseIdentityDocument(req.Storage, pkcs7B64)
if err != nil {
return nil, err
}